diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index bc0f1a3269..0000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,180 +0,0 @@ -version: 2 -jobs: - run_tests: - docker: - - image: python:3.7 - environment: - TERM: dumb - - steps: - - checkout - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Create a virtualenv - command: | - mkdir -p /tmp/venv/openfisca_core - python -m venv /tmp/venv/openfisca_core - echo "source /tmp/venv/openfisca_core/bin/activate" >> $BASH_ENV - - - run: - name: Install dependencies - command: | - make install - # pip install --editable git+https://github.com/openfisca/country-template.git@BRANCH_NAME#egg=OpenFisca-Country-Template # use a specific branch of OpenFisca-Country-Template - # pip install --editable git+https://github.com/openfisca/extension-template.git@BRANCH_NAME#egg=OpenFisca-Extension-Template # use a specific branch of OpenFisca-Extension-Template - - - save_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - paths: - - /tmp/venv/openfisca_core - - - run: - name: Run Core tests - command: env PYTEST_ADDOPTS="--exitfirst" make test - - - run: - name: Check NumPy typing against latest 3 minor versions - command: for i in {1..3}; do VERSION=$(.circleci/get-numpy-version.py prev) && pip install numpy==$VERSION && make check-types; done - - - persist_to_workspace: - root: . - paths: - - .coverage - - - run: - name: Run Country Template tests - command: | - COUNTRY_TEMPLATE_PATH=`python -c "import openfisca_country_template; print(openfisca_country_template.CountryTaxBenefitSystem().get_package_metadata()['location'])"` - openfisca test $COUNTRY_TEMPLATE_PATH/openfisca_country_template/tests/ - - test_docs: - docker: - - image: python:3.7 - environment: - TERM: dumb - - steps: - - checkout - - - run: - name: Checkout docs - command: make test-doc-checkout branch=$CIRCLE_BRANCH - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - restore_cache: - key: v1-py3-docs-{{ .Branch }}-{{ checksum "doc/requirements.txt" }} - - - run: - name: Create a virtualenv - command: | - mkdir -p /tmp/venv/openfisca_doc - python -m venv /tmp/venv/openfisca_doc - echo "source /tmp/venv/openfisca_doc/bin/activate" >> $BASH_ENV - - - run: - name: Install dependencies - command: make test-doc-install - - - save_cache: - key: v1-py3-docs-{{ .Branch }}-{{ checksum "doc/requirements.txt" }} - paths: - - /tmp/venv/openfisca_doc - - - run: - name: Run doc tests - command: make test-doc-build - - - check_version: - docker: - - image: python:3.7 - - steps: - - checkout - - - run: - name: Check version number has been properly updated - command: | - git fetch - .circleci/is-version-number-acceptable.sh - - submit_coverage: - docker: - - image: python:3.7 - - steps: - - checkout - - - attach_workspace: - at: . - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Submit coverage to Coveralls - command: | - source /tmp/venv/openfisca_core/bin/activate - pip install coveralls - coveralls - - - save_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - paths: - - /tmp/venv/openfisca_core - - deploy: - docker: - - image: python:3.7 - environment: - PYPI_USERNAME: openfisca-bot - # PYPI_PASSWORD: this value is set in CircleCI's web interface; do not set it here, it is a secret! - - steps: - - checkout - - - restore_cache: - key: v1-py3-{{ .Branch }}-{{ checksum "setup.py" }} - - - run: - name: Check for functional changes - command: if ! .circleci/has-functional-changes.sh ; then circleci step halt ; fi - - - run: - name: Upload a Python package to Pypi - command: | - source /tmp/venv/openfisca_core/bin/activate - .circleci/publish-python-package.sh - - - run: - name: Publish a git tag - command: .circleci/publish-git-tag.sh - - - run: - name: Update doc - command: | - curl -X POST --header "Content-Type: application/json" -d '{"branch":"master"}' https://circleci.com/api/v1.1/project/github/openfisca/openfisca-doc/build?circle-token=$CIRCLE_TOKEN - -workflows: - version: 2 - build_and_deploy: - jobs: - - run_tests - - test_docs - - check_version - - submit_coverage: - requires: - - run_tests - - deploy: - requires: - - run_tests - - test_docs - - check_version - filters: - branches: - only: master diff --git a/.circleci/get-numpy-version.py b/.circleci/get-numpy-version.py deleted file mode 100755 index 64cb68532e..0000000000 --- a/.circleci/get-numpy-version.py +++ /dev/null @@ -1,38 +0,0 @@ -#! /usr/bin/env python - -from __future__ import annotations - -import os -import sys -import typing -from packaging import version -from typing import NoReturn, Union - -import numpy - -if typing.TYPE_CHECKING: - from packaging.version import LegacyVersion, Version - - -def prev() -> NoReturn: - release = _installed().release - - if release is None: - sys.exit(os.EX_DATAERR) - - major, minor, _ = release - - if minor == 0: - sys.exit(os.EX_DATAERR) - - minor -= 1 - print(f"{major}.{minor}.0") # noqa: T001 - sys.exit(os.EX_OK) - - -def _installed() -> Union[LegacyVersion, Version]: - return version.parse(numpy.__version__) - - -if __name__ == "__main__": - globals()[sys.argv[1]]() diff --git a/.circleci/publish-git-tag.sh b/.circleci/publish-git-tag.sh deleted file mode 100755 index 4450357cbc..0000000000 --- a/.circleci/publish-git-tag.sh +++ /dev/null @@ -1,4 +0,0 @@ -#! /usr/bin/env bash - -git tag `python setup.py --version` -git push --tags # update the repository version diff --git a/.circleci/publish-python-package.sh b/.circleci/publish-python-package.sh deleted file mode 100755 index 8d331bd946..0000000000 --- a/.circleci/publish-python-package.sh +++ /dev/null @@ -1,4 +0,0 @@ -#! /usr/bin/env bash - -python setup.py bdist_wheel # build this package in the dist directory -twine upload dist/* --username $PYPI_USERNAME --password $PYPI_PASSWORD # publish diff --git a/.conda/README.md b/.conda/README.md new file mode 100644 index 0000000000..ac0d2c2be5 --- /dev/null +++ b/.conda/README.md @@ -0,0 +1,37 @@ +# Publish OpenFisca-Core to conda + +We use two systems to publish to conda: +- A fully automatic in OpenFisca-Core CI that publishes to an `openfisca` channel. See below for more information. +- A more complex in Conda-Forge CI, that publishes to [Conda-Forge](https://conda-forge.org). See this [YouTube video](https://www.youtube.com/watch?v=N2XwK9BkJpA) as an introduction to Conda-Forge, and [openfisca-core-feedstock repository](https://github.com/openfisca/openfisca-core-feedstock) for the project publishing process on Conda-Forge. + +We use both channels. With conda-forge users get an easier way to install and use openfisca-core: conda-forge is the default channel in Anaconda and it allows for publishing packages that depend on openfisca-core to conda-forge. + + +## Automatic upload + +The CI automatically uploads the PyPi package; see the `.github/workflow.yml`, step `publish-to-conda`. + +## Manual actions for first time publishing + +- Create an account on https://anaconda.org. +- Create a token on https://anaconda.org/openfisca/settings/access with `Allow write access to the API site`. Warning, it expires on 2023/01/13. + +- Put the token in a CI environment variable named `ANACONDA_TOKEN`. + + +## Manual actions to test before CI + +Everything is done by the CI but if you want to test it locally, here is how to do it. + +Do the following in the project root folder: + +- Auto-update `.conda/meta.yaml` with last infos from pypi by running: + - `python .github/get_pypi_info.py -p OpenFisca-Core` + +- Build package: + - `conda install -c anaconda conda-build anaconda-client` (`conda-build` to build the package and [anaconda-client](https://github.com/Anaconda-Platform/anaconda-client) to push the package to anaconda.org) + - `conda build -c conda-forge .conda` + + - Upload the package to Anaconda.org, but DON'T do it if you don't want to publish your locally built package as official openfisca-core library: + - `anaconda login` + - `anaconda upload openfisca-core--py_0.tar.bz2` diff --git a/.conda/openfisca-core/conda_build_config.yaml b/.conda/openfisca-core/conda_build_config.yaml new file mode 100644 index 0000000000..02754f3894 --- /dev/null +++ b/.conda/openfisca-core/conda_build_config.yaml @@ -0,0 +1,9 @@ +numpy: +- 1.24 +- 1.25 +- 1.26 + +python: +- 3.9 +- 3.10 +- 3.11 diff --git a/.conda/openfisca-core/meta.yaml b/.conda/openfisca-core/meta.yaml new file mode 100644 index 0000000000..be31e84b95 --- /dev/null +++ b/.conda/openfisca-core/meta.yaml @@ -0,0 +1,86 @@ +############################################################################### +## File for Anaconda.org +## It use Jinja2 templating code to retreive information from setup.py +############################################################################### + +{% set name = "OpenFisca-Core" %} +{% set data = load_setup_py_data() %} +{% set version = data.get('version') %} + +package: + name: {{ name|lower }} + version: {{ version }} + +source: + path: ../.. + +build: + noarch: python + number: 0 + script: "{{ PYTHON }} -m pip install . -vv" + entry_points: + - openfisca = openfisca_core.scripts.openfisca_command:main + - openfisca-run-test = openfisca_core.scripts.openfisca_command:main + +requirements: + host: + - numpy + - pip + - python + - setuptools >=61.0 + run: + - numpy + - python + {% for req in data['install_requires'] %} + {% if not req.startswith('numpy') %} + - {{ req }} + {% endif %} + {% endfor %} + +test: + imports: + - openfisca_core + - openfisca_core.commons + +outputs: + - name: openfisca-core + + - name: openfisca-core-api + build: + noarch: python + requirements: + host: + - numpy + - python + run: + - numpy + - python + {% for req in data['extras_require']['web-api'] %} + - {{ req }} + {% endfor %} + - {{ pin_subpackage('openfisca-core', exact=True) }} + + - name: openfisca-core-dev + build: + noarch: python + requirements: + host: + - numpy + - python + run: + - numpy + - python + {% for req in data['extras_require']['dev'] %} + - {{ req }} + {% endfor %} + - {{ pin_subpackage('openfisca-core-api', exact=True) }} + +about: + home: https://openfisca.org + license_family: AGPL + license: AGPL-3.0-only + license_file: LICENSE + summary: "A versatile microsimulation free software" + doc_url: https://openfisca.org + dev_url: https://github.com/openfisca/openfisca-core/ + description: This package contains the core features of OpenFisca, which are meant to be used by country packages such as OpenFisca-Country-Template. diff --git a/.conda/openfisca-country-template/recipe.yaml b/.conda/openfisca-country-template/recipe.yaml new file mode 100644 index 0000000000..7b75cf22c2 --- /dev/null +++ b/.conda/openfisca-country-template/recipe.yaml @@ -0,0 +1,42 @@ +schema_version: 1 + +context: + name: openfisca-country-template + version: 7.1.5 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/openfisca_country_template-${{ version }}.tar.gz + sha256: b2f2ac9945d9ccad467aed0925bd82f7f4d5ce4e96b212324cd071b8bee46914 + +build: + noarch: python + script: pip install . -v + +requirements: + host: + - numpy + - pip + - python + - setuptools >=61.0 + run: + - numpy + - python + - openfisca-core >=42,<43 + +tests: +- python: + imports: + - openfisca_country_template + +about: + summary: OpenFisca Rules as Code model for Country-Template. + license: AGPL-3.0 + license_file: LICENSE + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/openfisca-country-template/variants.yaml b/.conda/openfisca-country-template/variants.yaml new file mode 100644 index 0000000000..64e0aaf0f1 --- /dev/null +++ b/.conda/openfisca-country-template/variants.yaml @@ -0,0 +1,7 @@ +numpy: +- "1.26" + +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.conda/openfisca-extension-template/recipe.yaml b/.conda/openfisca-extension-template/recipe.yaml new file mode 100644 index 0000000000..03e53d5dd0 --- /dev/null +++ b/.conda/openfisca-extension-template/recipe.yaml @@ -0,0 +1,43 @@ +schema_version: 1 + +context: + name: openfisca-extension-template + version: 1.3.15 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/openfisca_extension_template-${{ version }}.tar.gz + sha256: e16ee9cbefdd5e9ddc1c2c0e12bcd74307c8cb1be55353b3b2788d64a90a5df9 + +build: + noarch: python + script: pip install . -v + +requirements: + host: + - numpy + - pip + - python + - setuptools >=61.0 + run: + - numpy + - python + - openfisca-country-template >=7,<8 + +tests: +- python: + imports: + - openfisca_extension_template + +about: + summary: An OpenFisca extension that adds some variables to an already-existing + tax and benefit system. + license: AGPL-3.0 + license_file: LICENSE + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/openfisca-extension-template/variants.yaml b/.conda/openfisca-extension-template/variants.yaml new file mode 100644 index 0000000000..64e0aaf0f1 --- /dev/null +++ b/.conda/openfisca-extension-template/variants.yaml @@ -0,0 +1,7 @@ +numpy: +- "1.26" + +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.conda/pylint-per-file-ignores/recipe.yaml b/.conda/pylint-per-file-ignores/recipe.yaml new file mode 100644 index 0000000000..4a573982f8 --- /dev/null +++ b/.conda/pylint-per-file-ignores/recipe.yaml @@ -0,0 +1,41 @@ +schema_version: 1 + +context: + name: pylint-per-file-ignores + version: 1.3.2 + +package: + name: ${{ name|lower }} + version: ${{ version }} + +source: + url: https://pypi.org/packages/source/${{ name[0] }}/${{ name }}/pylint_per_file_ignores-${{ version }}.tar.gz + sha256: 3c641f69c316770749a8a353556504dae7469541cdaef38e195fe2228841451e + +build: + noarch: python + script: pip install . -v + +requirements: + host: + - python + - poetry-core >=1.0.0 + - pip + run: + - pylint >=3.3.1,<4.0 + - python + - tomli >=2.0.1,<3.0.0 + +tests: +- python: + imports: + - pylint_per_file_ignores + +about: + summary: A pylint plugin to ignore error codes per file. + license: MIT + homepage: https://github.com/christopherpickering/pylint-per-file-ignores.git + +extra: + recipe-maintainers: + - bonjourmauko diff --git a/.conda/pylint-per-file-ignores/variants.yaml b/.conda/pylint-per-file-ignores/variants.yaml new file mode 100644 index 0000000000..ab419e422e --- /dev/null +++ b/.conda/pylint-per-file-ignores/variants.yaml @@ -0,0 +1,4 @@ +python: +- "3.9" +- "3.10" +- "3.11" diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..71eaf02d67 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,8 @@ +version: 2 +updates: +- package-ecosystem: pip + directory: / + schedule: + interval: monthly + labels: + - kind:dependencies diff --git a/.github/get_pypi_info.py b/.github/get_pypi_info.py new file mode 100644 index 0000000000..70013fbe98 --- /dev/null +++ b/.github/get_pypi_info.py @@ -0,0 +1,83 @@ +"""Script to get information needed by .conda/meta.yaml from PyPi JSON API. + +This script use get_info to get the info (yes !) and replace_in_file to +write them into .conda/meta.yaml. +Sample call: +python3 .github/get_pypi_info.py -p OpenFisca-Core +""" + +import argparse + +import requests + + +def get_info(package_name: str = "") -> dict: + """Get minimal information needed by .conda/meta.yaml from PyPi JSON API. + + ::package_name:: Name of package to get infos from. + ::return:: A dict with last_version, url and sha256 + """ + if package_name == "": + msg = "Package name not provided." + raise ValueError(msg) + url = f"https://pypi.org/pypi/{package_name}/json" + print(f"Calling {url}") # noqa: T201 + resp = requests.get(url) + if resp.status_code != 200: + msg = f"ERROR calling PyPI ({url}) : {resp}" + raise Exception(msg) + resp = resp.json() + version = resp["info"]["version"] + + for v in resp["releases"][version]: + # Find packagetype=="sdist" to get source code in .tar.gz + if v["packagetype"] == "sdist": + return { + "last_version": version, + "url": v["url"], + "sha256": v["digests"]["sha256"], + } + return {} + + +def replace_in_file(filepath: str, info: dict) -> None: + """Replace placeholder in meta.yaml by their values. + + ::filepath:: Path to meta.yaml, with filename. + ::info:: Dict with information to populate. + """ + with open(filepath, encoding="utf-8") as fin: + meta = fin.read() + # Replace with info from PyPi + meta = meta.replace("PYPI_VERSION", info["last_version"]) + meta = meta.replace("PYPI_URL", info["url"]) + meta = meta.replace("PYPI_SHA256", info["sha256"]) + with open(filepath, "w", encoding="utf-8") as fout: + fout.write(meta) + print(f"File {filepath} has been updated with info from PyPi.") # noqa: T201 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-p", + "--package", + type=str, + default="", + required=True, + help="The name of the package", + ) + parser.add_argument( + "-f", + "--filename", + type=str, + default=".conda/openfisca-core/meta.yaml", + help="Path to meta.yaml, with filename", + ) + args = parser.parse_args() + info = get_info(args.package) + print( # noqa: T201 + "Information of the last published PyPi package :", + info["last_version"], + ) + replace_in_file(args.filename, info) diff --git a/.circleci/has-functional-changes.sh b/.github/has-functional-changes.sh similarity index 91% rename from .circleci/has-functional-changes.sh rename to .github/has-functional-changes.sh index 049a94d6cd..bf1270989a 100755 --- a/.circleci/has-functional-changes.sh +++ b/.github/has-functional-changes.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -IGNORE_DIFF_ON="README.md CONTRIBUTING.md Makefile .gitignore LICENSE* .circleci/* .github/* tests/*" +IGNORE_DIFF_ON="README.md CONTRIBUTING.md Makefile .gitignore LICENSE* .github/* tests/* openfisca_tasks/*.mk tasks/*.mk" last_tagged_commit=`git describe --tags --abbrev=0 --first-parent` # --first-parent ensures we don't follow tags not published in master through an unlikely intermediary merge commit diff --git a/.circleci/is-version-number-acceptable.sh b/.github/is-version-number-acceptable.sh similarity index 95% rename from .circleci/is-version-number-acceptable.sh rename to .github/is-version-number-acceptable.sh index ae370e2a17..0f704a93fe 100755 --- a/.circleci/is-version-number-acceptable.sh +++ b/.github/is-version-number-acceptable.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -if [[ $CIRCLE_BRANCH == master ]] +if [[ ${GITHUB_REF#refs/heads/} == master ]] then echo "No need for a version check on master." exit 0 diff --git a/.github/workflows/_before-conda.yaml b/.github/workflows/_before-conda.yaml new file mode 100644 index 0000000000..7528a6a1c2 --- /dev/null +++ b/.github/workflows/_before-conda.yaml @@ -0,0 +1,109 @@ +name: Setup conda + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + +defaults: + run: + shell: bash -l {0} + +jobs: + setup: + runs-on: ${{ inputs.os }} + name: conda-setup-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + # To colorize output of make tasks. + TERM: xterm-256color + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Cache conda env + uses: actions/cache@v4 + with: + path: | + /usr/share/miniconda/envs/openfisca + ~/.conda/envs/openfisca + .env.yaml + key: conda-env-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + restore-keys: conda-env-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}- + id: cache-env + + - name: Cache conda deps + uses: actions/cache@v4 + with: + path: ~/conda_pkgs_dir + key: conda-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + restore-keys: conda-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}- + id: cache-deps + + - name: Cache release + uses: actions/cache@v4 + with: + path: ~/conda-rel + key: conda-release-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Setup conda + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: openfisca + miniforge-version: latest + python-version: ${{ inputs.python }} + use-mamba: true + if: steps.cache-env.outputs.cache-hit != 'true' + + - name: Install dependencies + run: mamba install boa rattler-build anaconda-client + if: steps.cache-env.outputs.cache-hit != 'true' + + - name: Update conda & dependencies + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: openfisca + environment-file: .env.yaml + miniforge-version: latest + use-mamba: true + if: steps.cache-env.outputs.cache-hit == 'true' + + - name: Build pylint plugin package + run: | + rattler-build build \ + --recipe .conda/pylint-per-file-ignores \ + --output-dir ~/conda-rel + + - name: Build core package + run: | + conda mambabuild .conda/openfisca-core \ + --use-local \ + --no-anaconda-upload \ + --output-folder ~/conda-rel \ + --numpy ${{ inputs.numpy }} \ + --python ${{ inputs.python }} + + - name: Build country template package + run: | + rattler-build build \ + --recipe .conda/openfisca-country-template \ + --output-dir ~/conda-rel \ + + - name: Build extension template package + run: | + rattler-build build \ + --recipe .conda/openfisca-extension-template \ + --output-dir ~/conda-rel + + - name: Export env + run: mamba env export --name openfisca > .env.yaml diff --git a/.github/workflows/_before-pip.yaml b/.github/workflows/_before-pip.yaml new file mode 100644 index 0000000000..02554419c8 --- /dev/null +++ b/.github/workflows/_before-pip.yaml @@ -0,0 +1,103 @@ +name: Setup package + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + + activate_command: + required: true + type: string + +jobs: + deps: + runs-on: ${{ inputs.os }} + name: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + # To colorize output of make tasks. + TERM: xterm-256color + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Use zstd for faster cache restore (windows) + if: ${{ startsWith(inputs.os, 'windows') }} + shell: cmd + run: echo C:\Program Files\Git\usr\bin>>"%GITHUB_PATH%" + + - name: Cache dependencies + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + restore-keys: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}- + + - name: Install dependencies + run: | + python -m venv venv + ${{ inputs.activate_command }} + make install-deps install-dist + pip install numpy==${{ inputs.numpy }} + + build: + runs-on: ${{ inputs.os }} + needs: [deps] + name: pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + TERM: xterm-256color + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Use zstd for faster cache restore (windows) + if: ${{ startsWith(inputs.os, 'windows') }} + shell: cmd + run: echo C:\Program Files\Git\usr\bin>>"%GITHUB_PATH%" + + - name: Cache dependencies + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Cache build + uses: actions/cache@v4 + with: + path: venv/**/[Oo]pen[Ff]isca* + key: pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + restore-keys: | + pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}- + pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}- + + - name: Cache release + uses: actions/cache@v4 + with: + path: dist + key: pip-release-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Build package + run: | + ${{ inputs.activate_command }} + make install-test clean build diff --git a/.github/workflows/_lint-pip.yaml b/.github/workflows/_lint-pip.yaml new file mode 100644 index 0000000000..e994f473e3 --- /dev/null +++ b/.github/workflows/_lint-pip.yaml @@ -0,0 +1,57 @@ +name: Lint package + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + + activate_command: + required: true + type: string + +jobs: + lint: + runs-on: ${{ inputs.os }} + name: pip-lint-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + TERM: xterm-256color # To colorize output of make tasks. + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Use zstd for faster cache restore (windows) + if: ${{ startsWith(inputs.os, 'windows') }} + shell: cmd + run: echo C:\Program Files\Git\usr\bin>>"%GITHUB_PATH%" + + - name: Cache dependencies + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Lint doc + run: | + ${{ inputs.activate_command }} + make clean check-syntax-errors lint-doc + + - name: Lint styles + run: | + ${{ inputs.activate_command }} + make clean check-syntax-errors check-style diff --git a/.github/workflows/_test-conda.yaml b/.github/workflows/_test-conda.yaml new file mode 100644 index 0000000000..fab88ac1df --- /dev/null +++ b/.github/workflows/_test-conda.yaml @@ -0,0 +1,76 @@ +name: Test conda package + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + +defaults: + run: + shell: bash -l {0} + +jobs: + test: + runs-on: ${{ inputs.os }} + name: conda-test-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + TERM: xterm-256color # To colorize output of make tasks. + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Cache conda env + uses: actions/cache@v4 + with: + path: | + /usr/share/miniconda/envs/openfisca + ~/.conda/envs/openfisca + .env.yaml + key: conda-env-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Cache conda deps + uses: actions/cache@v4 + with: + path: ~/conda_pkgs_dir + key: conda-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Cache release + uses: actions/cache@v4 + with: + path: ~/conda-rel + key: conda-release-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Update conda & dependencies + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: openfisca + environment-file: .env.yaml + miniforge-version: latest + use-mamba: true + + - name: Install packages + run: | + mamba install --channel file:///home/runner/conda-rel \ + openfisca-core-dev \ + openfisca-country-template \ + openfisca-extension-template + + - name: Run core tests + run: make test-core + + - name: Run country tests + run: make test-country + + - name: Run extension tests + run: make test-extension diff --git a/.github/workflows/_test-pip.yaml b/.github/workflows/_test-pip.yaml new file mode 100644 index 0000000000..704c7fbd17 --- /dev/null +++ b/.github/workflows/_test-pip.yaml @@ -0,0 +1,72 @@ +name: Test package + +on: + workflow_call: + inputs: + os: + required: true + type: string + + numpy: + required: true + type: string + + python: + required: true + type: string + + activate_command: + required: true + type: string + +jobs: + test: + runs-on: ${{ inputs.os }} + name: pip-test-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TERM: xterm-256color # To colorize output of make tasks. + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Use zstd for faster cache restore (windows) + if: ${{ startsWith(inputs.os, 'windows') }} + shell: cmd + run: echo C:\Program Files\Git\usr\bin>>"%GITHUB_PATH%" + + - name: Cache dependencies + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }} + + - name: Cache build + uses: actions/cache@v4 + with: + path: venv/**/[Oo]pen[Ff]isca* + key: pip-build-${{ inputs.os }}-np${{ inputs.numpy }}-py${{ inputs.python }}-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Run Openfisca Core tests + run: | + ${{ inputs.activate_command }} + make test-core + python -m coveralls --service=github + + - name: Run Country Template tests + if: ${{ startsWith(inputs.os, 'ubuntu') }} + run: | + ${{ inputs.activate_command }} + make test-country + + - name: Run Extension Template tests + if: ${{ startsWith(inputs.os, 'ubuntu') }} + run: | + ${{ inputs.activate_command }} + make test-extension diff --git a/.github/workflows/_version.yaml b/.github/workflows/_version.yaml new file mode 100644 index 0000000000..27c4737a4f --- /dev/null +++ b/.github/workflows/_version.yaml @@ -0,0 +1,38 @@ +name: Check version + +on: + workflow_call: + inputs: + os: + required: true + type: string + + python: + required: true + type: string + +jobs: + # The idea behind these dependencies is that we want to give feedback to + # contributors on the version number only after they have passed all tests, + # so they don't have to do it twice after changes happened to the main branch + # during the time they took to fix the tests. + check-version: + runs-on: ${{ inputs.os }} + env: + # To colorize output of make tasks. + TERM: xterm-256color + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + # Fetch all the tags + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python }} + + - name: Check version number has been properly updated + run: ${GITHUB_WORKSPACE}/.github/is-version-number-acceptable.sh diff --git a/.github/workflows/merge.yaml b/.github/workflows/merge.yaml new file mode 100644 index 0000000000..5c2b4c791d --- /dev/null +++ b/.github/workflows/merge.yaml @@ -0,0 +1,247 @@ +name: OpenFisca-Core / Deploy package to PyPi & Conda + +on: + push: + branches: [master] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +jobs: + setup-pip: + strategy: + fail-fast: true + matrix: + os: [ubuntu-22.04, windows-2019] + numpy: [1.26.4, 1.24.2] + # Patch version must be specified to avoid any cache confusion, since + # the cache key depends on the full Python version. If left unspecified, + # different patch versions could be allocated between jobs, and any + # such difference would lead to a cache not found error. + python: [3.11.9, 3.9.13] + include: + - os: ubuntu-22.04 + activate_command: source venv/bin/activate + - os: windows-2019 + activate_command: .\venv\Scripts\activate + uses: ./.github/workflows/_before-pip.yaml + with: + os: ${{ matrix.os }} + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: ${{ matrix.activate_command }} + + setup-conda: + uses: ./.github/workflows/_before-conda.yaml + with: + os: ubuntu-22.04 + numpy: 1.26.4 + python: 3.10.6 + + test-pip: + needs: [setup-pip] + strategy: + fail-fast: true + matrix: + os: [ubuntu-22.04, windows-2019] + numpy: [1.26.4, 1.24.2] + python: [3.11.9, 3.9.13] + include: + - os: ubuntu-22.04 + activate_command: source venv/bin/activate + - os: windows-2019 + activate_command: .\venv\Scripts\activate + uses: ./.github/workflows/_test-pip.yaml + with: + os: ${{ matrix.os }} + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: ${{ matrix.activate_command }} + + test-conda: + uses: ./.github/workflows/_test-conda.yaml + needs: [setup-conda] + with: + os: ubuntu-22.04 + numpy: 1.26.4 + python: 3.10.6 + + lint-pip: + needs: [setup-pip] + strategy: + fail-fast: true + matrix: + numpy: [1.24.2] + python: [3.11.9, 3.9.13] + uses: ./.github/workflows/_lint-pip.yaml + with: + os: ubuntu-22.04 + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: source venv/bin/activate + + check-version: + needs: [test-pip, test-conda, lint-pip] + uses: ./.github/workflows/_version.yaml + with: + os: ubuntu-22.04 + python: 3.9.13 + + # GitHub Actions does not have a halt job option, to stop from deploying if + # no functional changes were found. We build a separate job to substitute the + # halt option. The `deploy` job is dependent on the output of the + # `check-for-functional-changes`job. + check-for-functional-changes: + runs-on: ubuntu-22.04 + # Last job to run + needs: [check-version] + outputs: + status: ${{ steps.stop-early.outputs.status }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.9.13 + + - id: stop-early + # The `check-for-functional-changes` job should always succeed regardless + # of the `has-functional-changes` script's exit code. Consequently, we do + # not use that exit code to trigger deploy, but rather a dedicated output + # variable `status`, to avoid a job failure if the exit code is different + # from 0. Conversely, if the job fails the entire workflow would be + # marked as `failed` which is disturbing for contributors. + run: if "${GITHUB_WORKSPACE}/.github/has-functional-changes.sh" ; then echo + "::set-output name=status::success" ; fi + + publish-to-pypi: + runs-on: ubuntu-22.04 + needs: [check-for-functional-changes] + if: needs.check-for-functional-changes.outputs.status == 'success' + env: + PYPI_USERNAME: __token__ + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_OPENFISCA_BOT }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 3.9.13 + + - name: Cache deps + uses: actions/cache@v4 + with: + path: venv + key: pip-deps-ubuntu-22.04-np1.24.2-py3.9.13-${{ hashFiles('setup.py') }} + + - name: Cache build + uses: actions/cache@v4 + with: + path: venv/**/[oO]pen[fF]isca* + key: pip-build-ubuntu-22.04-np1.24.2-py3.9.13-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Cache release + uses: actions/cache@v4 + with: + path: dist + key: pip-release-ubuntu-22.04-np1.24.2-py3.9.13-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Upload package to PyPi + run: | + source venv/bin/activate + make publish + + - name: Update doc + run: | + curl -L \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.OPENFISCADOC_BOT_ACCESS_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/openfisca/openfisca-doc/actions/workflows/deploy.yaml/dispatches \ + -d '{"ref":"main"}' + + publish-to-conda: + runs-on: ubuntu-22.04 + needs: [publish-to-pypi] + defaults: + run: + shell: bash -l {0} + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Cache conda env + uses: actions/cache@v4 + with: + path: | + /usr/share/miniconda/envs/openfisca + ~/.conda/envs/openfisca + .env.yaml + key: conda-env-ubuntu-22.04-np1.26.4-py3.10.6-${{ hashFiles('setup.py') }} + + - name: Cache conda deps + uses: actions/cache@v4 + with: + path: ~/conda_pkgs_dir + key: conda-deps-ubuntu-22.04-np1.26.4-py3.10.6-${{ hashFiles('setup.py') }} + + - name: Cache release + uses: actions/cache@v4 + with: + path: ~/conda-rel + key: conda-release-ubuntu-22.04-np1.26.4-py3.10.6-${{ hashFiles('setup.py') }}-${{ github.sha }} + + - name: Update conda & dependencies + uses: conda-incubator/setup-miniconda@v3 + with: + activate-environment: openfisca + environment-file: .env.yaml + miniforge-version: latest + use-mamba: true + + - name: Publish to conda + shell: bash -l {0} + run: | + anaconda upload ~/conda-rel/noarch/openfisca-core-* \ + --token ${{ secrets.ANACONDA_TOKEN }} + --user openfisca + --force + + test-on-windows: + runs-on: windows-2019 + needs: [publish-to-conda] + defaults: + run: + shell: bash -l {0} + + steps: + - uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + # See GHA Windows + # https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json + python-version: 3.10.6 + channels: conda-forge + activate-environment: true + + - name: Checkout + uses: actions/checkout@v4 + + - name: Install with conda + shell: bash -l {0} + run: conda install -c openfisca openfisca-core + + - name: Test openfisca + shell: bash -l {0} + run: openfisca --help diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml new file mode 100644 index 0000000000..7bee48c81c --- /dev/null +++ b/.github/workflows/push.yaml @@ -0,0 +1,89 @@ +name: OpenFisca-Core / Pull request review + +on: + pull_request: + types: [assigned, opened, reopened, synchronize, ready_for_review] + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +jobs: + setup-pip: + strategy: + fail-fast: true + matrix: + os: [ubuntu-22.04, windows-2019] + numpy: [1.26.4, 1.24.2] + # Patch version must be specified to avoid any cache confusion, since + # the cache key depends on the full Python version. If left unspecified, + # different patch versions could be allocated between jobs, and any + # such difference would lead to a cache not found error. + python: [3.11.9, 3.9.13] + include: + - os: ubuntu-22.04 + activate_command: source venv/bin/activate + - os: windows-2019 + activate_command: .\venv\Scripts\activate + uses: ./.github/workflows/_before-pip.yaml + with: + os: ${{ matrix.os }} + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: ${{ matrix.activate_command }} + + setup-conda: + uses: ./.github/workflows/_before-conda.yaml + with: + os: ubuntu-22.04 + numpy: 1.26.4 + python: 3.10.6 + + test-pip: + needs: [setup-pip] + strategy: + fail-fast: true + matrix: + os: [ubuntu-22.04, windows-2019] + numpy: [1.26.4, 1.24.2] + python: [3.11.9, 3.9.13] + include: + - os: ubuntu-22.04 + activate_command: source venv/bin/activate + - os: windows-2019 + activate_command: .\venv\Scripts\activate + uses: ./.github/workflows/_test-pip.yaml + with: + os: ${{ matrix.os }} + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: ${{ matrix.activate_command }} + + test-conda: + uses: ./.github/workflows/_test-conda.yaml + needs: [setup-conda] + with: + os: ubuntu-22.04 + numpy: 1.26.4 + python: 3.10.6 + + lint-pip: + needs: [setup-pip] + strategy: + fail-fast: true + matrix: + numpy: [1.24.2] + python: [3.11.9, 3.9.13] + uses: ./.github/workflows/_lint-pip.yaml + with: + os: ubuntu-22.04 + numpy: ${{ matrix.numpy }} + python: ${{ matrix.python }} + activate_command: source venv/bin/activate + + check-version: + needs: [test-pip, test-conda, lint-pip] + uses: ./.github/workflows/_version.yaml + with: + os: ubuntu-22.04 + python: 3.9.13 diff --git a/.gitignore b/.gitignore index 4b56efc6da..00d9f19fb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,22 +1,22 @@ -.venv -.project -.spyderproject -.pydevproject -.vscode -.settings/ -.vscode/ -build/ -dist/ -doc/ *.egg-info *.mo *.pyc *~ -/cover -/.coverage -/tags -.tags* +.coverage +.mypy_cache +.mypy_cache* .noseids +.project +.pydevproject .pytest_cache -.mypy_cache +.settings +.spyderproject +.tags* +.venv +.vscode +build +cover +dist +doc performance.json +tags diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a5a3f99ea..bf2962fd85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,19 +1,614 @@ # Changelog -## 35.6.0 [#1033](https://github.com/openfisca/openfisca-core/pull/1033) +### 42.0.4 [#1257](https://github.com/openfisca/openfisca-core/pull/1257) + +#### Technical changes + +- Fix conda test and publish +- Add matrix testing to CI + - Now it tests lower and upper bounds of python and numpy versions + +### 42.0.3 [#1234](https://github.com/openfisca/openfisca-core/pull/1234) + +#### Technical changes + +- Add matrix testing to CI + - Now it tests lower and upper bounds of python and numpy versions + +> Note: Version `42.0.3` has been unpublished as was deployed by mistake. +> Please use versions `42.0.4` and subsequents. + +### 42.0.2 [#1256](https://github.com/openfisca/openfisca-core/pull/1256) + +#### Documentation + +- Fix bad indent + +### 42.0.1 [#1253](https://github.com/openfisca/openfisca-core/pull/1253) + +#### Documentation + +- Fix documentation of `entities` + +# 42.0.0 [#1223](https://github.com/openfisca/openfisca-core/pull/1223) + +#### Breaking changes + +- Changes to `eternity` instants and periods + - Eternity instants are now `` instead of + `` + - Eternity periods are now `, -1))>` + instead of `, inf))>` + - The reason is to avoid mixing data types: `inf` is a float, periods and + instants are integers. Mixed data types make memory optimisations impossible. + - Migration should be straightforward. If you have a test that checks for + `inf`, you should update it to check for `-1` or use the `is_eternal` method. +- `periods.instant` no longer returns `None` + - Now, it raises `periods.InstantError` + +#### New features + +- Introduce `Instant.eternity()` + - This behaviour was duplicated across + - Now it is encapsulated in a single method +- Introduce `Instant.is_eternal` and `Period.is_eternal` + - These methods check if the instant or period are eternity (`bool`). +- Now `periods.instant` parses also ISO calendar strings (weeks) + - For instance, `2022-W01` is now a valid input + +#### Technical changes + +- Update `pendulum` +- Reduce code complexity +- Remove run-time type-checks +- Add typing to the periods module + +### 41.5.7 [#1225](https://github.com/openfisca/openfisca-core/pull/1225) + +#### Technical changes + +- Refactor & test `eval_expression` + +### 41.5.6 [#1185](https://github.com/openfisca/openfisca-core/pull/1185) + +#### Technical changes + +- Remove pre Python 3.9 syntax. + +### 41.5.5 [#1220](https://github.com/openfisca/openfisca-core/pull/1220) + +#### Technical changes + +- Fix doc & type definitions in the entities module + +### 41.5.4 [#1219](https://github.com/openfisca/openfisca-core/pull/1219) + +#### Technical changes + +- Fix doc & type definitions in the commons module + +### 41.5.3 [#1218](https://github.com/openfisca/openfisca-core/pull/1218) + +#### Technical changes + +- Fix `flake8` doc linting: + - Add format "google" + - Fix per-file skips +- Fix failing lints + +### 41.5.2 [#1217](https://github.com/openfisca/openfisca-core/pull/1217) + +#### Technical changes + +- Fix styles by applying `isort`. +- Add a `isort` dry-run check to `make lint` + +### 41.5.1 [#1216](https://github.com/openfisca/openfisca-core/pull/1216) + +#### Technical changes + +- Fix styles by applying `black`. +- Add a `black` dry-run check to `make lint` + +## 41.5.0 [#1212](https://github.com/openfisca/openfisca-core/pull/1212) + +#### New features + +- Introduce `VectorialAsofDateParameterNodeAtInstant` + - It is a parameter node of the legislation at a given instant which has been vectorized along some date. + - Vectorized parameters allow requests such as parameters.housing_benefit[date], where date is a `numpy.datetime64` vector + +### 41.4.7 [#1211](https://github.com/openfisca/openfisca-core/pull/1211) + +#### Technical changes + +- Update documentation continuous deployment method to reflect OpenFisca-Doc [process updates](https://github.com/openfisca/openfisca-doc/pull/308) + +### 41.4.6 [#1210](https://github.com/openfisca/openfisca-core/pull/1210) + +#### Technical changes + +- Abide by OpenAPI v3.0.0 instead of v3.1.0 + - Drop support for `propertyNames` in `Values` definition + +### 41.4.5 [#1209](https://github.com/openfisca/openfisca-core/pull/1209) + +#### Technical changes + +- Support loading metadata from both `setup.py` and `pyproject.toml` package description files. + +### ~41.4.4~ [#1208](https://github.com/openfisca/openfisca-core/pull/1208) + +_Unpublished due to introduced backwards incompatibilities._ + +#### Technical changes + +- Adapt testing pipeline to Country Template [v7](https://github.com/openfisca/country-template/pull/139). + +### 41.4.3 [#1206](https://github.com/openfisca/openfisca-core/pull/1206) + +#### Technical changes + +- Increase spiral and cycle tests robustness. + - The current test is ambiguous, as it hides a failure at the first spiral + occurrence (from 2017 to 2016). + +### 41.4.2 [#1203](https://github.com/openfisca/openfisca-core/pull/1203) + +#### Technical changes + +- Changes the Pypi's deployment authentication way to use token API following Pypi's 2FA enforcement starting 2024/01/01. + +### 41.4.1 [#1202](https://github.com/openfisca/openfisca-core/pull/1202) + +#### Technical changes + +- Check that entities are fully specified when expanding over axes. + +## 41.4.0 [#1197](https://github.com/openfisca/openfisca-core/pull/1197) + +#### New features + +- Add `entities.find_role()` to find roles by key and `max`. + +#### Technical changes + +- Document `projectors.get_projector_from_shortcut()`. + +## 41.3.0 [#1200](https://github.com/openfisca/openfisca-core/pull/1200) + +> As `TracingParameterNodeAtInstant` is a wrapper for `ParameterNodeAtInstant` +> which allows iteration and the use of `contains`, it was not possible +> to use those on a `TracingParameterNodeAtInstant` + +#### New features + +- Allows iterations on `TracingParameterNodeAtInstant` +- Allows keyword `contains` on `TracingParameterNodeAtInstant` + +## 41.2.0 [#1199](https://github.com/openfisca/openfisca-core/pull/1199) + +#### Technical changes + +- Fix `openfisca-core` Web API error triggered by `Gunicorn` < 22.0. + - Bump `Gunicorn` major revision to fix error on Web API. + Source: https://github.com/benoitc/gunicorn/issues/2564 + +### 41.1.2 [#1192](https://github.com/openfisca/openfisca-core/pull/1192) + +#### Technical changes + +- Add tests to `entities`. + +### 41.1.1 [#1186](https://github.com/openfisca/openfisca-core/pull/1186) + +#### Technical changes + +- Skip type-checking tasks + - Before their definition was commented out but still run with `make test` + - Now they're skipped but not commented, which is needed to fix the + underlying issues + +## 41.1.0 [#1195](https://github.com/openfisca/openfisca-core/pull/1195) + +#### Technical changes + +- Make `Role` explicitly hashable. +- Details: + - By introducing `__eq__`, naturally `Role` became unhashable, because + equality was calculated based on a property of `Role` + (`role.key == another_role.key`), and no longer structurally + (`"1" == "1"`). + - This changeset removes `__eq__`, as `Role` is being used downstream as a + hashable object, and adds a test to ensure `Role`'s hashability. + +### 41.0.2 [#1194](https://github.com/openfisca/openfisca-core/pull/1194) + +#### Technical changes + +- Add `__hash__` method to `Role`. + +### 41.0.1 [#1187](https://github.com/openfisca/openfisca-core/pull/1187) + +#### Technical changes + +- Document `Role`. + +# 41.0.0 [#1189](https://github.com/openfisca/openfisca-core/pull/1189) + +#### Breaking changes + +- `Variable.get_introspection_data` no longer has parameters nor calling functions + +The Web API was very prone to crashing, timeouting at startup because of the time consuming python file parsing to generate documentation displayed for instance in the Legislation Explorer. + +## 40.1.0 [#1174](https://github.com/openfisca/openfisca-core/pull/1174) + +#### New Features + +- Allows for dispatching and dividing inputs over a broader range. + - For example, divide a monthly variable by week. + +### 40.0.1 [#1184](https://github.com/openfisca/openfisca-core/pull/1184) + +#### Technical changes + +- Require numpy < 1.25 because of memory leak detected in OpenFisca-France. + +# 40.0.0 [#1181](https://github.com/openfisca/openfisca-core/pull/1181) + +#### Breaking changes + +- Upgrade every dependencies to its latest version. +- Upgrade to Python >= 3.9 + +Note: Checks on mypy typings are disabled, because they cause generate of errors that we were not able to fix easily. + +# 39.0.0 [#1181](https://github.com/openfisca/openfisca-core/pull/1181) + +#### Breaking changes + +- Upgrade every dependencies to their latest versions. +- Upgrade to Python >= 3.9 + +Main changes, that may require some code changes in country packages: +- numpy +- pytest +- Flask + +### 38.0.4 [#1182](https://github.com/openfisca/openfisca-core/pull/1182) + +#### Technical changes + +- Method `_get_tax_benefit_system()` of class `YamlItem` in file `openfisca_core/tools/test_runner.py` will now clone the TBS when applying reforms to avoid running tests with previously reformed TBS. + +### 38.0.3 [#1179](https://github.com/openfisca/openfisca-core/pull/1179) + +#### Bug fix + +- Do not install dependencies outside the `setup.py` + - Dependencies installed outside the `setup.py` are not taken into account by + `pip`'s dependency resolver. + - In case of conflicting transient dependencies, the last library installed + will "impose" its dependency version. + - This makes the installation and build of the library non-deterministic and + prone to unforeseen bugs caused by external changes in dependencies' + versions. + +#### Note + +A definite way to solve this issue is to clearly separate library dependencies +(with a `virtualenv`) and a universal dependency installer for CI requirements +(like `pipx`), taking care of: + +- Always running tests inside the `virtualenv` (for example with `nox`). +- Always building outside of the `virtualenv` (for example with `poetry` + installed by `pipx`). + +Moreover, it is indeed even better to have a lock file for dependencies, +using `pip freeze`) or with tools providing such features (`pipenv`, etc.). + +### 38.0.2 [#1178](https://github.com/openfisca/openfisca-core/pull/1178) + +#### Technical changes + +- Remove use of `importlib_metadata`. + +### 38.0.1 - + +> Note: Version `38.0.1` has been unpublished as was deployed by mistake. +> Please use versions `38.0.2` and subsequents. + + +# 38.0.0 [#989](https://github.com/openfisca/openfisca-core/pull/989) + +> Note: Version `38.0.0` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### New Features + +- Upgrade OpenAPI specification of the API to v3 from Swagger v2. +- Continuously validate OpenAPI specification. + +#### Breaking changes + +- Drop support for OpenAPI specification v2 and prior. + - Users relying on OpenAPI v2 can use [Swagger Converter](https://converter.swagger.io/api/convert?url=OAS2_YAML_OR_JSON_URL) to migrate ([example](https://web.archive.org/web/20221103230822/https://converter.swagger.io/api/convert?url=https://api.demo.openfisca.org/latest/spec)). + +### 37.0.2 [#1170](https://github.com/openfisca/openfisca-core/pull/1170) + +> Note: Version `37.0.2` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Technical changes + +- Always import numpy + +### 37.0.1 [#1169](https://github.com/openfisca/openfisca-core/pull/1169) + +> Note: Version `37.0.1` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Technical changes + +- Unify casing of NumPy. + +# 37.0.0 [#1142](https://github.com/openfisca/openfisca-core/pull/1142) + +> Note: Version `37.0.0` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Deprecations + +- In _periods.Instant_: + - Remove `period`, method used to build a `Period` from an `Instant`. + - This method created an upward circular dependency between `Instant` and `Period` causing lots of trouble. + - The functionality is still provided by `periods.period` and the `Period` constructor. + +#### Migration details + +- Replace `some_period.start.period` and similar methods with `Period((unit, some_period.start, 1))`. + +# 36.0.0 [#1149](https://github.com/openfisca/openfisca-core/pull/1162) + +> Note: Version `36.0.0` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Breaking changes + +- In `ParameterScaleBracket`: + - Remove the `base` attribute + - The attribute's usage was unclear and it was only being used by some French social security variables + +## 35.12.0 [#1160](https://github.com/openfisca/openfisca-core/pull/1160) + +> Note: Version `35.12.0` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### New Features + +- Lighter install by removing test packages from systematic install. + +### 35.11.2 [#1166](https://github.com/openfisca/openfisca-core/pull/1166) + +> Note: Version `35.11.2` has been unpublished as `35.11.1` introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Technical changes + +- Fix Holder's doctests. + +### 35.11.1 [#1165](https://github.com/openfisca/openfisca-core/pull/1165) + +> Note: Version `35.11.1` has been unpublished as it introduced a bug +> preventing users to load a tax-benefit system. Please use versions `38.0.2` +> and subsequents. + +#### Bug fix + +- Fix documentation + - Suppression of some modules broke the documentation build + +## 35.11.0 [#1149](https://github.com/openfisca/openfisca-core/pull/1149) + +#### New Features + +- Introduce variable dependent error margins in YAML tests. + +### 35.10.1 [#1143](https://github.com/openfisca/openfisca-core/pull/1143) + +#### Bug fix + +- Reintroduce support for the ``day`` date unit in `holders.set_input_dispatch_by_period` and `holders. + set_input_divide_by_period` + - Allows for dispatching values per day, for example, to provide a daily (week, fortnight) to an yearly variable. + - Inversely, allows for calculating the daily (week, fortnight) value of a yearly input. + +## 35.10.0 [#1151](https://github.com/openfisca/openfisca-core/pull/1151) + +#### New features + +- Add type hints for all instances of `variable_name` in function declarations. +- Add type hints for some `Simulation` and `Population` properties. + +## 35.9.0 [#1150](https://github.com/openfisca/openfisca-core/pull/1150) + +#### New Features + +- Introduce a maximal depth for computation logs + - Allows for limiting the depth of the computation log chain + +### 35.8.6 [#1145](https://github.com/openfisca/openfisca-core/pull/1145) + +#### Technical changes + +- Removes the automatic documentation build check + - It has been proven difficult to maintain, specifically due _dependency hell_ and a very contrived build workflow. + +### 35.8.5 [#1137](https://github.com/openfisca/openfisca-core/pull/1137) + +#### Technical changes + +- Fix pylint dependency in fresh editable installations + - Ignore pytest requirement, used to collect test cases, if it is not yet installed. + +### 35.8.4 [#1131](https://github.com/openfisca/openfisca-core/pull/1131) + +#### Technical changes + +- Correct some type hints and docstrings. + +### 35.8.3 [#1127](https://github.com/openfisca/openfisca-core/pull/1127) + +#### Technical changes + +- Fix the build for Anaconda in CI. The conda build failed on master because of a replacement in a comment string. + - The _ were removed in the comment to avoid a replace. + +### 35.8.2 [#1128](https://github.com/openfisca/openfisca-core/pull/1128) + +#### Technical changes + +- Remove ambiguous links in docstrings. + +### 35.8.1 [#1105](https://github.com/openfisca/openfisca-core/pull/1105) + +#### Technical changes + +- Add publish to Anaconda in CI. See file .conda/README.md. + +## 35.8.0 [#1114](https://github.com/openfisca/openfisca-core/pull/1114) + +#### New Features + +- Introduce `rate_from_bracket_indice` method on `RateTaxScaleLike` class + - Allows for the determination of the tax rate based on the tax bracket indice + +- Introduce `rate_from_tax_base` method on `RateTaxScaleLike` class + - Allows for the determination of the tax rate based on the tax base + +- Introduce `threshold_from_tax_base` method on `RateTaxScaleLike` class + - Allows for the determination of the lower threshold based on the tax base + +- Add publish openfisca-core library to Anaconda in CI. See file .conda/README.md. + +### 35.7.8 [#1086](https://github.com/openfisca/openfisca-core/pull/1086) + +#### Technical changes + +### 35.7.7 [#1109](https://github.com/openfisca/openfisca-core/pull/1109) + +#### Technical changes + +- Fix `openfisca-core` Web API error triggered by `Flask` dependencies updates + - Bump `Flask` patch revision to fix `cannot import name 'json' from 'itsdangerous'` on Web API. + - Then, fix `MarkupSafe` revision to avoid `cannot import name 'soft_unicode' from 'markupsafe'` error on Web API. + +### 35.7.6 [#1065](https://github.com/openfisca/openfisca-core/pull/1065) + +#### Technical changes + +- Made code compatible with dpath versions >=1.5.0,<3.0.0, instead of >=1.5.0,<2.0.0 + +### 35.7.5 [#1090](https://github.com/openfisca/openfisca-core/pull/1090) + +#### Technical changes + +- Remove calls to deprecated imp module + +### 35.7.4 [#1083](https://github.com/openfisca/openfisca-core/pull/1083) + +#### Technical changes + +- Add GitHub `pull-request` event as a trigger to GitHub Actions workflow + +### 35.7.3 [#1081](https://github.com/openfisca/openfisca-core/pull/1081) + +- Correct error message in case of mis-sized population + +### 35.7.2 [#1057](https://github.com/openfisca/openfisca-core/pull/1057) + +#### Technical changes + +- Switch CI provider from CircleCI to GitHub Actions + +### 35.7.1 [#1075](https://github.com/openfisca/openfisca-core/pull/1075) + +#### Bug fix + +- Fix the collection of OpenFisca-Core tests coverage data + - Tests within `openfisca_core/*` were not run + +## 35.7.0 [#1070](https://github.com/openfisca/openfisca-core/pulls/1070) + +#### New Features + +- Add group population shortcut to containing groups entities + +## 35.6.0 [#1054](https://github.com/openfisca/openfisca-core/pull/1054) #### New Features - Introduce `openfisca_core.types` -#### Bug Fixes +#### Documentation -- Fix doctests of the commons module +- Complete typing of the commons module + +#### Dependencies + +- `nptyping` + - To add backport-support for numpy typing + - Can be removed once lower-bound numpy version is 1.21+ + +- `typing_extensions` + - To add backport-support for `typing.Protocol` and `typing.Literal` + - Can be removed once lower-bound python version is 3.8+ + +### 35.5.5 [#1055](https://github.com/openfisca/openfisca-core/pull/1055) #### Documentation - Complete the documentation of the commons module +### 35.5.4 [#1033](https://github.com/openfisca/openfisca-core/pull/1033) + +#### Bug Fixes + +- Fix doctests of the commons module + +#### Dependencies + +- `darglint`, `flake8-docstrings`, & `pylint` + - For automatic docstring linting & validation. + +### 35.5.3 [#1020](https://github.com/openfisca/openfisca-core/pull/1020) + +#### Technical changes + +- Run openfisca-core & country/extension template tests systematically + +### 35.5.2 [#1048](https://github.com/openfisca/openfisca-core/pull/1048) + +#### Bug fix + +- In _test_yaml.py_: + - Fix yaml tests loading —required for testing against the built version. + +### 35.5.1 [#1046](https://github.com/openfisca/openfisca-core/pull/1046) + +#### Non-technical changes + +- Reorganise `Makefile` into context files (install, test, publish…) +- Colorise `make` tasks and improve messages printed to the user + ## 35.5.0 [#1038](https://github.com/openfisca/openfisca-core/pull/1038) #### New Features @@ -75,7 +670,7 @@ - When libraries do not implement their own types, MyPy provides stubs, or type sheds - Thanks to `__future__.annotations`, those stubs or type sheds are casted to `typing.Any` - Since 1.20.x, NumPy now provides their own type definitions - - The introduction of NumPy 1.20.x in #990 caused one major problem: + - The introduction of NumPy 1.20.x in #990 caused one major problem: - It is general practice to do not import at runtime modules only used for typing purposes, thanks to the `typing.TYPE_CHEKING` variable - The new `numpy.typing` module was being imported at runtime, rendering OpenFisca unusable to all users depending on previous versions of NumPy (1.20.x-) - These changes revert #990 and solve #1009 and #1012 @@ -213,22 +808,22 @@ _Note: this version has been unpublished due to an issue introduced by NumPy upg #### Breaking changes -- Update Numpy version's upper bound to 1.18 - - Numpy 1.18 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.18.0-notes.html#expired-deprecations) that might be used in openfisca country models. +- Update NumPy version's upper bound to 1.18 + - NumPy 1.18 [expires a list of old deprecations](https://numpy.org/devdocs/release/1.18.0-notes.html#expired-deprecations) that might be used in openfisca country models. #### Migration details -You might need to change your code if any of the [Numpy expired deprecations](https://numpy.org/devdocs/release/1.18.0-notes.html#expired-deprecations) is used in your model formulas. +You might need to change your code if any of the [NumPy expired deprecations](https://numpy.org/devdocs/release/1.18.0-notes.html#expired-deprecations) is used in your model formulas. Here is a subset of the deprecations that you might find in your model with some checks and migration steps (where `np` stands for `numpy`): -* `Removed deprecated support for boolean and empty condition lists in np.select.` - * Before `np.select([], [])` result was `0` (for a `default` argument value set to `0`). +* `Removed deprecated support for boolean and empty condition lists in numpy.select.` + * Before `numpy.select([], [])` result was `0` (for a `default` argument value set to `0`). * Now, we have to check for empty conditions and, return `0` or the defined default argument value when we want to keep the same behavior. * Before, integer conditions where transformed to booleans. - * For example, `np.select([0, 1, 0], ['a', 'b', 'c'])` result was `array('b', dtype=' ``` > And two parameters `parameters.city_tax.z1` and `parameters.city_tax.z2`, they can be dynamically accessed through: > ```py -> zone = np.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) +> zone = numpy.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) > zone_value = parameters.rate._get_at_instant('2015-01-01').single.owner[zone] > ``` > returns @@ -2030,9 +2625,9 @@ class housing_occupancy_status(Variable): ```py holder = simulation.household.get_holder('housing_occupancy_status') # Three possibilities -holder.set_input(period, np.asarray([HousingOccupancyStatus.owner])) -holder.set_input(period, np.asarray(['owner'])) -holder.set_input(period, np.asarray([0])) # Highly not recommanded +holder.set_input(period, numpy.asarray([HousingOccupancyStatus.owner])) +holder.set_input(period, numpy.asarray(['owner'])) +holder.set_input(period, numpy.asarray([0])) # Highly not recommanded ``` - When calculating an Enum variable, the output will be an [EnumArray](https://openfisca.org/doc/openfisca-python-api/enum_array.html#module-openfisca_core.indexed_enums). @@ -2978,7 +3573,7 @@ These breaking changes only concern variable and tax and benefit system **metada ### 4.3.4 -* Fix occasionnal `NaN` creation in `MarginalRateTaxScale.calc` resulting from `0 * np.inf` +* Fix occasionnal `NaN` creation in `MarginalRateTaxScale.calc` resulting from `0 * numpy.inf` ### 4.3.3 diff --git a/Makefile b/Makefile index 1a555ad9d1..6ace90d5f5 100644 --- a/Makefile +++ b/Makefile @@ -1,174 +1,34 @@ -info = $$(tput setaf 6)[i]$$(tput sgr0) -warn = $$(tput setaf 3)[!]$$(tput sgr0) -work = $$(tput setaf 5)[⚙]$$(tput sgr0) -pass = echo $$(tput setaf 2)[✓]$$(tput sgr0) Good work! =\> $$(tput setaf 8)$1$$(tput sgr0)$$(tput setaf 2)passed$$(tput sgr0) $$(tput setaf 1)❤$$(tput sgr0) -help = sed -n "/^$1/ { x ; p ; } ; s/\#\#/$(work)/ ; s/\./.../ ; x" ${MAKEFILE_LIST} -repo = https://github.com/openfisca/openfisca-doc -branch = $(shell git branch --show-current) +include openfisca_tasks/install.mk +include openfisca_tasks/lint.mk +include openfisca_tasks/publish.mk +include openfisca_tasks/serve.mk +include openfisca_tasks/test_code.mk -## Same as `make test`. -all: test - @$(call pass,$@:) - -## Run tests, type and style linters. -test: clean check-syntax-errors check-style check-types test-code - @$(call pass,$@:) - -## Install project dependencies. -install: - @$(call help,$@:) - @pip install --upgrade pip twine wheel - @pip install --editable .[dev] --upgrade --use-deprecated=legacy-resolver - -## Install openfisca-core for deployment and publishing. -build: setup.py - @## This allows us to be sure tests are run against the packaged version - @## of openfisca-core, the same we put in the hands of users and reusers. - @$(call help,$@:) - @python $? bdist_wheel - @find dist -name "*.whl" -exec pip install --force-reinstall {}[dev] \; - -## Uninstall project dependencies. -uninstall: - @$(call help,$@:) - @pip freeze | grep -v "^-e" | sed "s/@.*//" | xargs pip uninstall -y - -## Delete builds and compiled python files. -clean: \ - $(shell ls -d * | grep "build\|dist") \ - $(shell find . -name "*.pyc") - @$(call help,$@:) - @rm -rf $? - -## Compile python files to check for syntax errors. -check-syntax-errors: . - @$(call help,$@:) - @python -m compileall -q $? - @$(call pass,$@:) +## To share info with the user, but no action is needed. +print_info = $$(tput setaf 6)[i]$$(tput sgr0) -## Run linters to check for syntax and style errors. -check-style: \ - check-style-all \ - check-style-doc-commons \ - check-style-doc-entities \ - check-style-doc-indexed_enums \ - check-style-doc-types - @$(call pass,$@:) +## To warn the user of something, but no action is needed. +print_warn = $$(tput setaf 3)[!]$$(tput sgr0) -## Run linters to check for syntax and style errors. -check-style-all: - @$(call help,$@:) - @flake8 `git ls-files | grep "\.py$$"` +## To let the user know where we are in the task pipeline. +print_work = $$(tput setaf 5)[⚙]$$(tput sgr0) -## Run linters to check for syntax and style errors. -check-style-doc-%: - @$(call help,$@:) - @flake8 --select=D101,D102,D103,DAR openfisca_core/$* - @pylint --enable=classes,exceptions,imports,miscellaneous,refactoring --disable=W0201,W0231 openfisca_core/$* +## To let the user know the task in progress succeded. +## The `$1` is a function argument, passed from a task (usually the task name). +print_pass = echo $$(tput setaf 2)[✓]$$(tput sgr0) $$(tput setaf 8)$1$$(tput sgr0)$$(tput setaf 2)passed$$(tput sgr0) $$(tput setaf 1)❤$$(tput sgr0) -## Run static type checkers for type errors. -check-types: \ - check-types-all \ - check-types-strict-commons \ - check-types-strict-entities \ - check-types-strict-indexed_enums \ - check-types-strict-types - @$(call pass,$@:) +## Similar to `print_work`, but this will read the comments above a task, and +## print them to the user at the start of each task. The `$1` is a function +## argument. +print_help = sed -n "/^$1/ { x ; p ; } ; s/\#\#/\r$(print_work)/ ; s/\./…/ ; x" ${MAKEFILE_LIST} -## Run static type checkers for type errors. -check-types-all: - @$(call help,$@:) - @mypy --package openfisca_core --package openfisca_web_api +## Same as `make`. +.DEFAULT_GOAL := all -## Run static type checkers for type errors. -check-types-strict-%: - @$(call help,$@:) - @mypy --cache-dir .mypy_cache-openfisca_core.$* --implicit-reexport --strict --package openfisca_core.$* - -## Run openfisca core & web-api tests. -test-code: - @$(call help,$@:) - @PYTEST_ADDOPTS="${PYTEST_ADDOPTS}" pytest --cov=openfisca_core --cov=openfisca_web_api - @$(call pass,$@:) - -## Check that the current changes do not break the doc. -test-doc: - @## Usage: - @## - @## make test-doc [branch=BRANCH] - @## - @## Examples: - @## - @## # Will check the current branch in openfisca-doc. - @## make test-doc - @## - @## # Will check "test-doc" in openfisca-doc. - @## make test-doc branch=test-doc - @## - @## # Will check "master" if "asdf1234" does not exist. - @## make test-doc branch=asdf1234 - @## - @$(call help,$@:) - @${MAKE} test-doc-checkout - @${MAKE} test-doc-install - @${MAKE} test-doc-build - @$(call pass,$@:) - -## Update the local copy of the doc. -test-doc-checkout: - @$(call help,$@:) - @[ ! -d doc ] && git clone ${repo} doc || : - @cd doc && { \ - git reset --hard ; \ - git fetch --all ; \ - [ "$$(git branch --show-current)" != "master" ] && git checkout master || : ; \ - [ "${branch}" != "master" ] \ - && { \ - { \ - >&2 echo "$(info) Trying to checkout the branch 'openfisca-doc/${branch}'..." ; \ - git branch -D ${branch} 2> /dev/null ; \ - git checkout ${branch} 2> /dev/null ; \ - } \ - && git pull --ff-only origin ${branch} \ - || { \ - >&2 echo "$(warn) The branch 'openfisca-doc/${branch}' was not found, falling back to 'openfisca-doc/master'..." ; \ - >&2 echo "" ; \ - >&2 echo "$(info) This is perfectly normal, one of two things can ensue:" ; \ - >&2 echo "$(info)" ; \ - >&2 echo "$(info) $$(tput setaf 2)[If tests pass]$$(tput sgr0)" ; \ - >&2 echo "$(info) * No further action required on your side..." ; \ - >&2 echo "$(info)" ; \ - >&2 echo "$(info) $$(tput setaf 1)[If tests fail]$$(tput sgr0)" ; \ - >&2 echo "$(info) * Create the branch '${branch}' in 'openfisca-doc'... " ; \ - >&2 echo "$(info) * Push your fixes..." ; \ - >&2 echo "$(info) * Run 'make test-doc' again..." ; \ - >&2 echo "" ; \ - >&2 echo "$(work) Checking out 'openfisca-doc/master'..." ; \ - git pull --ff-only origin master ; \ - } \ - } \ - || git pull --ff-only origin master ; \ - } 1> /dev/null - -## Install doc dependencies. -test-doc-install: - @$(call help,$@:) - @pip install --requirement doc/requirements.txt 1> /dev/null - @pip install --editable .[dev] --upgrade 1> /dev/null - -## Dry-build the doc. -test-doc-build: - @$(call help,$@:) - @sphinx-build -M dummy doc/source doc/build -n -q -W - -## Run code formatters to correct style errors. -format-style: $(shell git ls-files "*.py") - @$(call help,$@:) - @autopep8 $? +## Same as `make test`. +all: test + @$(call print_pass,$@:) -## Serve the openfisca Web API. -api: - @$(call help,$@:) - @openfisca serve \ - --country-package openfisca_country_template \ - --extensions openfisca_extension_template +## Run all lints and tests. +test: clean lint test-code + @$(call print_pass,$@:) diff --git a/README.md b/README.md index 7f253c9114..3974b4908c 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,17 @@ # OpenFisca Core -[![Newsletter](https://img.shields.io/badge/newsletter-subscribe!-informational.svg?style=flat)](mailto:contact%40openfisca.org?subject=Subscribe%20to%20your%20newsletter%20%7C%20S'inscrire%20%C3%A0%20votre%20newsletter&body=%5BEnglish%20version%20below%5D%0A%0ABonjour%2C%0A%0AVotre%C2%A0pr%C3%A9sence%C2%A0ici%C2%A0nous%C2%A0ravit%C2%A0!%20%F0%9F%98%83%0A%0AEnvoyez-nous%20cet%20email%20pour%20que%20l'on%20puisse%20vous%20inscrire%20%C3%A0%20la%20newsletter.%20%0A%0AAh%C2%A0!%20Et%20si%20vous%20pouviez%20remplir%20ce%20petit%20questionnaire%2C%20%C3%A7a%20serait%20encore%20mieux%C2%A0!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2F45M0VR1TYKD1RGzX2%0A%0AAmiti%C3%A9%2C%0AL%E2%80%99%C3%A9quipe%20OpenFisca%0A%0A%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%20ENGLISH%20VERSION%20%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%0A%0AHi%2C%20%0A%0AWe're%20glad%20to%20see%20you%20here!%20%F0%9F%98%83%0A%0APlease%20send%20us%20this%20email%2C%20so%20we%20can%20subscribe%20you%20to%20the%20newsletter.%0A%0AAlso%2C%20if%20you%20can%20fill%20out%20this%20short%20survey%2C%20even%20better!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2FsOg8K1abhhm441LG2%0A%0ACheers%2C%0AThe%20OpenFisca%20Team) -[![Twitter](https://img.shields.io/badge/twitter-follow%20us!-9cf.svg?style=flat)](https://twitter.com/intent/follow?screen_name=openfisca) -[![Slack](https://img.shields.io/badge/slack-join%20us!-blueviolet.svg?style=flat)](mailto:contact%40openfisca.org?subject=Join%20you%20on%20Slack%20%7C%20Nous%20rejoindre%20sur%20Slack&body=%5BEnglish%20version%20below%5D%0A%0ABonjour%2C%0A%0AVotre%C2%A0pr%C3%A9sence%C2%A0ici%C2%A0nous%C2%A0ravit%C2%A0!%20%F0%9F%98%83%0A%0ARacontez-nous%20un%20peu%20de%20vous%2C%20et%20du%20pourquoi%20de%20votre%20int%C3%A9r%C3%AAt%20de%20rejoindre%20la%20communaut%C3%A9%20OpenFisca%20sur%20Slack.%0A%0AAh%C2%A0!%20Et%20si%20vous%20pouviez%20remplir%20ce%20petit%20questionnaire%2C%20%C3%A7a%20serait%20encore%20mieux%C2%A0!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2F45M0VR1TYKD1RGzX2%0A%0AN%E2%80%99oubliez%20pas%20de%20nous%20envoyer%20cet%20email%C2%A0!%20Sinon%2C%20on%20ne%20pourra%20pas%20vous%20contacter%20ni%20vous%20inviter%20sur%20Slack.%0A%0AAmiti%C3%A9%2C%0AL%E2%80%99%C3%A9quipe%20OpenFisca%0A%0A%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%20ENGLISH%20VERSION%20%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%0A%0AHi%2C%20%0A%0AWe're%20glad%20to%20see%20you%20here!%20%F0%9F%98%83%0A%0APlease%20tell%20us%20a%20bit%20about%20you%20and%20why%20you%20want%20to%20join%20the%20OpenFisca%20community%20on%20Slack.%0A%0AAlso%2C%20if%20you%20can%20fill%20out%20this%20short%20survey%2C%20even%20better!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2FsOg8K1abhhm441LG2.%0A%0ADon't%20forget%20to%20send%20us%20this%20email!%20Otherwise%20we%20won't%20be%20able%20to%20contact%20you%20back%2C%20nor%20invite%20you%20on%20Slack.%0A%0ACheers%2C%0AThe%20OpenFisca%20Team) +[![PyPi Downloads](https://img.shields.io/pypi/dm/openfisca-core?label=pypi%2Fdownloads&style=for-the-badge)](https://pepy.tech/project/openfisca-core) +[![PyPi Version](https://img.shields.io/pypi/v/openfisca-core.svg?label=pypi%2Fversion&style=for-the-badge)](https://pypi.python.org/pypi/openfisca-core) +[![Conda Downloads](https://img.shields.io/conda/dn/conda-forge/openfisca-core?label=conda%2Fdownloads&style=for-the-badge)](https://anaconda.org/conda-forge/openfisca-core) +[![Conda Version](https://img.shields.io/conda/vn/conda-forge/openfisca-core.svg?label=conda%2Fversion&style=for-the-badge)](https://anaconda.org/conda-forge/openfisca-core) -[![CircleCI](https://img.shields.io/circleci/project/github/openfisca/openfisca-core/master.svg?style=flat)](https://circleci.com/gh/openfisca/openfisca-core) -[![Coveralls](https://img.shields.io/coveralls/github/openfisca/openfisca-core/master.svg?style=flat)](https://coveralls.io/github/openfisca/openfisca-core?branch=master) -[![Python](https://img.shields.io/pypi/pyversions/openfisca-core.svg)](https://pypi.python.org/pypi/openfisca-core) -[![PyPi](https://img.shields.io/pypi/v/openfisca-core.svg?style=flat)](https://pypi.python.org/pypi/openfisca-core) +[![Python](https://img.shields.io/pypi/pyversions/openfisca-core.svg?label=python&style=for-the-badge)](https://pypi.python.org/pypi/openfisca-core) +[![Coveralls](https://img.shields.io/coveralls/github/openfisca/openfisca-core/master.svg?label=code%20coverage&style=for-the-badge)](https://coveralls.io/github/openfisca/openfisca-core?branch=master) +[![Contributors](https://img.shields.io/github/contributors/openfisca/openfisca-core.svg?style=for-the-badge)](https://github.com/openfisca/openfisca-core/graphs/contributors) + +[![Newsletter](https://img.shields.io/badge/newsletter-subscribe!-informational.svg?style=for-the-badge)](mailto:contact%40openfisca.org?subject=Subscribe%20to%20your%20newsletter%20%7C%20S'inscrire%20%C3%A0%20votre%20newsletter&body=%5BEnglish%20version%20below%5D%0A%0ABonjour%2C%0A%0AVotre%C2%A0pr%C3%A9sence%C2%A0ici%C2%A0nous%C2%A0ravit%C2%A0!%20%F0%9F%98%83%0A%0AEnvoyez-nous%20cet%20email%20pour%20que%20l'on%20puisse%20vous%20inscrire%20%C3%A0%20la%20newsletter.%20%0A%0AAh%C2%A0!%20Et%20si%20vous%20pouviez%20remplir%20ce%20petit%20questionnaire%2C%20%C3%A7a%20serait%20encore%20mieux%C2%A0!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2F45M0VR1TYKD1RGzX2%0A%0AAmiti%C3%A9%2C%0AL%E2%80%99%C3%A9quipe%20OpenFisca%0A%0A%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%20ENGLISH%20VERSION%20%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%0A%0AHi%2C%20%0A%0AWe're%20glad%20to%20see%20you%20here!%20%F0%9F%98%83%0A%0APlease%20send%20us%20this%20email%2C%20so%20we%20can%20subscribe%20you%20to%20the%20newsletter.%0A%0AAlso%2C%20if%20you%20can%20fill%20out%20this%20short%20survey%2C%20even%20better!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2FsOg8K1abhhm441LG2%0A%0ACheers%2C%0AThe%20OpenFisca%20Team) +[![Twitter](https://img.shields.io/badge/twitter-follow%20us!-9cf.svg?style=for-the-badge)](https://twitter.com/intent/follow?screen_name=openfisca) +[![Slack](https://img.shields.io/badge/slack-join%20us!-blueviolet.svg?style=for-the-badge)](mailto:contact%40openfisca.org?subject=Join%20you%20on%20Slack%20%7C%20Nous%20rejoindre%20sur%20Slack&body=%5BEnglish%20version%20below%5D%0A%0ABonjour%2C%0A%0AVotre%C2%A0pr%C3%A9sence%C2%A0ici%C2%A0nous%C2%A0ravit%C2%A0!%20%F0%9F%98%83%0A%0ARacontez-nous%20un%20peu%20de%20vous%2C%20et%20du%20pourquoi%20de%20votre%20int%C3%A9r%C3%AAt%20de%20rejoindre%20la%20communaut%C3%A9%20OpenFisca%20sur%20Slack.%0A%0AAh%C2%A0!%20Et%20si%20vous%20pouviez%20remplir%20ce%20petit%20questionnaire%2C%20%C3%A7a%20serait%20encore%20mieux%C2%A0!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2F45M0VR1TYKD1RGzX2%0A%0AN%E2%80%99oubliez%20pas%20de%20nous%20envoyer%20cet%20email%C2%A0!%20Sinon%2C%20on%20ne%20pourra%20pas%20vous%20contacter%20ni%20vous%20inviter%20sur%20Slack.%0A%0AAmiti%C3%A9%2C%0AL%E2%80%99%C3%A9quipe%20OpenFisca%0A%0A%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%20ENGLISH%20VERSION%20%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%3D%0A%0AHi%2C%20%0A%0AWe're%20glad%20to%20see%20you%20here!%20%F0%9F%98%83%0A%0APlease%20tell%20us%20a%20bit%20about%20you%20and%20why%20you%20want%20to%20join%20the%20OpenFisca%20community%20on%20Slack.%0A%0AAlso%2C%20if%20you%20can%20fill%20out%20this%20short%20survey%2C%20even%20better!%0Ahttps%3A%2F%2Fgoo.gl%2Fforms%2FsOg8K1abhhm441LG2.%0A%0ADon't%20forget%20to%20send%20us%20this%20email!%20Otherwise%20we%20won't%20be%20able%20to%20contact%20you%20back%2C%20nor%20invite%20you%20on%20Slack.%0A%0ACheers%2C%0AThe%20OpenFisca%20Team) [OpenFisca](https://openfisca.org/doc/) is a versatile microsimulation free software. Check the [online documentation](https://openfisca.org/doc/) for more details. @@ -15,27 +19,69 @@ This package contains the core features of OpenFisca, which are meant to be used ## Environment -OpenFisca runs on Python 3.7. More recent versions should work, but are not tested. +OpenFisca runs on Python 3.7. More recent versions should work but are not tested. -OpenFisca also relies strongly on NumPy. Last four minor versions should work, but only latest/stable is tested. +OpenFisca also relies strongly on NumPy. The last four minor versions should work, but only the latest/stable is tested. ## Installation -If you're developing your own country package, you don't need to explicitly install OpenFisca-Core. It just needs to appear [in your package dependencies](https://github.com/openfisca/openfisca-france/blob/18.2.1/setup.py#L53). +If you're developing your own country package, you don't need to explicitly install OpenFisca-Core. It just needs to appear [in your package dependencies](https://github.com/openfisca/openfisca-france/blob/100.0.0/setup.py#L60). +If you want to contribute to OpenFisca-Core itself, welcome! +To install it locally you can use one of these two options: +* [conda](https://docs.conda.io/en/latest/) package manager that we recommend for Windows operating system users, +* or standard Python [pip](https://packaging.python.org/en/latest/key_projects/#pip) package manager. + +### Installing `openfisca-core` with `pip` -If you want to contribute to OpenFisca-Core itself, welcome! To install it locally in development mode run the following commands: +This installation requires [Python](https://www.python.org/downloads/) 3.7+ and [GIT](https://git-scm.com) installations. + +To install `openfisca-core` locally in development mode run the following commands in a shell terminal: ```bash git clone https://github.com/openfisca/openfisca-core.git cd openfisca-core python3 -m venv .venv source .venv/bin/activate -pip install -U pip -pip install --editable .[dev] --use-deprecated=legacy-resolver +make install-deps install-edit ``` +### Installing `openfisca-core` with `conda` + +Since `openfisca-core` version [35.7.7](https://anaconda.org/conda-forge/openfisca-core), you could use `conda` to install OpenFisca-Core. + +Conda is the easiest way to use OpenFisca under Windows as by installing Anaconda you will get: +- Python +- The package manager [Anaconda.org](https://docs.anaconda.com/anacondaorg/user-guide/) +- A virtual environment manager : [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) +- A GUI [Anaconda Navigator](https://docs.anaconda.com/anaconda/navigator/index.html) if you choose to install the full [Anaconda](https://www.anaconda.com/products/individual) + +If you are familiar with the command line you could use [Miniconda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/windows.html), which needs very much less disk space than Anaconda. + +After installing conda, run these commands in an `Anaconda Powershell Prompt`: +- `conda create --name openfisca python=3.7` to create an `openfisca` environment. +- `conda activate openfisca` to use your new environment. + +Then, choose one of the following options according to your use case: +- `conda install -c conda-forge openfisca-core` for default dependencies, +- or `conda install -c conda-forge openfisca-core-api` if you want the Web API part, +- or `conda install -c conda-forge -c openfisca openfisca-core-dev` if you want all the dependencies needed to contribute to the project. + +For information on how we publish to conda-forge, see [openfisca-core-feedstock](https://github.com/openfisca/openfisca-core-feedstock/blob/master/recipe/README.md). + ## Testing +Install the test dependencies: + +``` +make install-deps install-edit install-test +``` + +> For integration testing purposes, `openfisca-core` relies on +> [country-template](https://github.com/openfisca/country-template.git) and +> [extension-template](https://github.com/openfisca/extension-template.git). +> Because these packages rely at the same time on `openfisca-core`, they need +> to be installed separately. + To run the entire test suite: ```sh @@ -98,74 +144,10 @@ END ## Documentation -Yet however OpenFisca does not follow a common convention for docstrings, our current toolchain allows to check whether documentation builds correctly and to update it automatically with each contribution to this repository. +OpenFisca’s toolchain checks whether documentation builds correctly and updates it automatically with each contribution to this repository. In the meantime, please take a look at our [contributing guidelines](CONTRIBUTING.md) for some general tips on how to document your contributions, and at our official documentation's [repository](https://github.com/openfisca/openfisca-doc/blob/master/README.md) to in case you want to know how to build it by yourself —and improve it! -### To verify that the documentation still builds correctly - -You can run: - -```sh -make test-doc -``` - -### If it doesn't, or if the doc is already broken. - -Here's how you can fix it: - -1. Clone the documentation, if not yet done: - -``` -make test-doc-checkout -``` - -2. Install the documentation's dependencies, if not yet done: - -``` -make test-doc-install -``` - -3. Create a branch, both in core and in the doc, to correct the problems: - -``` -git checkout -b fix-doc -sh -c "cd doc && git checkout -b `git branch --show-current`" -``` - -4. Fix the offending problems —they could be in core, in the doc, or in both. - -You can test-drive your fixes by checking that each change works as expected: - -``` -make test-doc-build branch=`git branch --show-current` -``` - -5. Commit at each step, so you don't accidentally lose your progress: - -``` -git add -A && git commit -m "Fix outdated argument for Entity" -sh -c "cd doc && git add -A && git commit -m \"Fix outdated argument for Entity\"" -``` - -6. Once you're done, push your changes and cleanup: - -``` -git push origin `git branch --show-current` -sh -c "cd doc && git push origin `git branch --show-current`" -rm -rf doc -``` - -7. Finally, open a pull request both in [core](https://github.com/openfisca/openfisca-core/compare/master...fix-doc) and in the [doc](https://github.com/openfisca/openfisca-doc/compare/master...fix-doc). - -[CircleCI](.circleci/config.yml) will automatically try to build the documentation from the same branch in both core and the doc (in our example "fix-doc") so we can integrate first our changes to core, and then our changes to the doc. - -If no changes were needed to the doc, then your changes to core will be verified against the production version of the doc. - -If your changes concern only the doc, please take a look at the doc's [README](https://github.com/openfisca/openfisca-doc/blob/master/README.md). - -That's it! 🙌 - ## Serving the API OpenFisca-Core provides a Web-API. It is by default served on the `5000` port. @@ -206,7 +188,7 @@ The tracker is activated when these two options are set: * `--tracker-url`: An URL ending with `piwik.php`. It defines the Piwik instance that will receive the tracking information. To use the main OpenFisca Piwik instance, use `https://stats.data.gouv.fr/piwik.php`. * `--tracker-idsite`: An integer. It defines the identifier of the tracked site on your Piwik instance. To use the main OpenFisca piwik instance, use `4`. -* `--tracker-token`: A string. It defines the Piwik API Authentification token to differentiate API calls based on the user IP. Otherwise, all API calls will seem to come from your server. The Piwik API Authentification token can be found in your Piwik interface, when you are logged. +* `--tracker-token`: A string. It defines the Piwik API Authentification token to differentiate API calls based on the user IP. Otherwise, all API calls will seem to come from your server. The Piwik API Authentification token can be found in your Piwik interface when you are logged in. For instance, to run the Web API with the mock country package `openfisca_country_template` and the tracker activated, run: diff --git a/conftest.py b/conftest.py index 3adc794111..fbe03e7d37 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,7 @@ pytest_plugins = [ "tests.fixtures.appclient", "tests.fixtures.entities", + "tests.fixtures.extensions", "tests.fixtures.simulations", "tests.fixtures.taxbenefitsystems", - ] +] diff --git a/openfisca_core/commons/__init__.py b/openfisca_core/commons/__init__.py index 76b391af1f..1a3d065ee1 100644 --- a/openfisca_core/commons/__init__.py +++ b/openfisca_core/commons/__init__.py @@ -1,15 +1,14 @@ """Common tools for contributors and users. -The tools included in this sub-package are intented, at the same time, to help -contributors who maintain OpenFisca Core, and to help users building their own -systems. +The tools in this sub-package are intended, to help both contributors +to OpenFisca Core and to country packages. Official Public API: - * :class:`.deprecated` * :func:`.apply_thresholds` * :func:`.average_rate` * :func:`.concat` * :func:`.empty_clone` + * :func:`.eval_expression` * :func:`.marginal_rate` * :func:`.stringify_array` * :func:`.switch` @@ -18,8 +17,8 @@ * :class:`.Dummy` Note: - The ``deprecated`` imports are transitional, as so to ensure non-breaking - changes, and could be definitely removed from the codebase in the next + The ``deprecated`` imports are transitional, in order to ensure non-breaking + changes, and could be removed from the codebase in the next major release. Note: @@ -29,11 +28,12 @@ from openfisca_core.commons.formulas import switch # Bad from openfisca_core.commons.decorators import deprecated # Bad - The previous examples provoke cyclic dependency problems, that prevents us - from modularizing the different components of the library, so as to make + + The previous examples provoke cyclic dependency problems, that prevent us + from modularizing the different components of the library, which would make them easier to test and to maintain. - How could them be used after the next major release:: + How they could be used in a future release:: from openfisca_core import commons from openfisca_core.commons import deprecated @@ -51,20 +51,21 @@ """ -# Official Public API - -from .decorators import deprecated # noqa: F401 -from .formulas import apply_thresholds, concat, switch # noqa: F401 -from .misc import empty_clone, stringify_array # noqa: F401 -from .rates import average_rate, marginal_rate # noqa: F401 - -__all__ = ["deprecated"] -__all__ = ["apply_thresholds", "concat", "switch", *__all__] -__all__ = ["empty_clone", "stringify_array", *__all__] -__all__ = ["average_rate", "marginal_rate", *__all__] - -# Deprecated - -from .dummy import Dummy # noqa: F401 - -__all__ = ["Dummy", *__all__] +from . import types +from .dummy import Dummy +from .formulas import apply_thresholds, concat, switch +from .misc import empty_clone, eval_expression, stringify_array +from .rates import average_rate, marginal_rate + +__all__ = [ + "Dummy", + "apply_thresholds", + "average_rate", + "concat", + "empty_clone", + "eval_expression", + "marginal_rate", + "stringify_array", + "switch", + "types", +] diff --git a/openfisca_core/commons/decorators.py b/openfisca_core/commons/decorators.py deleted file mode 100644 index 2041a7f4f7..0000000000 --- a/openfisca_core/commons/decorators.py +++ /dev/null @@ -1,76 +0,0 @@ -import functools -import warnings -import typing -from typing import Any, Callable, TypeVar - -T = Callable[..., Any] -F = TypeVar("F", bound = T) - - -class deprecated: - """Allows (soft) deprecating a functionality of OpenFisca. - - Attributes: - since (:obj:`str`): Since when the functionality is deprecated. - expires (:obj:`str`): When will it be removed forever? - - Args: - since: Since when the functionality is deprecated. - expires: When will it be removed forever? - - Examples: - >>> @deprecated(since = "35.5.0", expires = "in the future") - ... def obsolete(): - ... return "I'm obsolete!" - - >>> repr(obsolete) - '' - - >>> str(obsolete) - '' - - .. versionadded:: 35.6.0 - - """ - - since: str - expires: str - - def __init__(self, since: str, expires: str) -> None: - self.since = since - self.expires = expires - - def __call__(self, function: F) -> F: - """Wraps a function to return another one, decorated. - - Args: - function: The function or method to decorate. - - Returns: - :obj:`callable`: The decorated function. - - Examples: - >>> def obsolete(): - ... return "I'm obsolete!" - - >>> decorator = deprecated( - ... since = "35.5.0", - ... expires = "in the future", - ... ) - - >>> decorator(obsolete) - - - """ - - def wrapper(*args: Any, **kwds: Any) -> Any: - message = [ - f"{function.__qualname__} has been deprecated since", - f"version {self.since}, and will be removed in", - f"{self.expires}.", - ] - warnings.warn(" ".join(message), DeprecationWarning) - return function(*args, **kwds) - - functools.update_wrapper(wrapper, function) - return typing.cast(F, wrapper) diff --git a/openfisca_core/commons/dummy.py b/openfisca_core/commons/dummy.py index 732ed49a65..b9fc31d89f 100644 --- a/openfisca_core/commons/dummy.py +++ b/openfisca_core/commons/dummy.py @@ -1,4 +1,4 @@ -from .decorators import deprecated +import warnings class Dummy: @@ -9,10 +9,17 @@ class Dummy: None: - ... + message = [ + "The 'Dummy' class has been deprecated since version 34.7.0,", + "and will be removed in the future.", + ] + warnings.warn(" ".join(message), DeprecationWarning, stacklevel=2) + + +__all__ = ["Dummy"] diff --git a/openfisca_core/commons/formulas.py b/openfisca_core/commons/formulas.py index a750717f1d..a184ad2dc4 100644 --- a/openfisca_core/commons/formulas.py +++ b/openfisca_core/commons/formulas.py @@ -1,33 +1,33 @@ -from typing import Any, Dict, List +from __future__ import annotations + +from collections.abc import Mapping import numpy -from openfisca_core.types import ArrayLike, ArrayType +from . import types as t def apply_thresholds( - input: ArrayType[float], - thresholds: ArrayLike[float], - choices: ArrayLike[float], - ) -> ArrayType[float]: + input: t.Array[numpy.float32], + thresholds: t.ArrayLike[float], + choices: t.ArrayLike[float], +) -> t.Array[numpy.float32]: """Makes a choice based on an input and thresholds. - From list of ``choices``, it selects one of them based on a list of - inputs, depending on the position of each ``input`` whithin a list of - ``thresholds``. It does so for each ``input`` provided. + From a list of ``choices``, this function selects one of these values + based on a list of inputs, depending on the value of each ``input`` within + a list of ``thresholds``. Args: - input: A list of inputs to make a choice. + input: A list of inputs to make a choice from. thresholds: A list of thresholds to choose. - choices: A list of the possible choices. + choices: A list of the possible values to choose from. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - A list of the choices made. + Array[numpy.float32]: A list of the values chosen. Raises: - :exc:`AssertionError`: When the number of ``thresholds`` (t) and the - number of choices (c) are not either t == c or t == c - 1. + AssertionError: When thresholds and choices are incompatible. Examples: >>> input = numpy.array([4, 5, 6, 7, 8]) @@ -38,8 +38,7 @@ def apply_thresholds( """ - condlist: List[ArrayType[bool]] - + condlist: list[t.Array[numpy.bool_] | bool] condlist = [input <= threshold for threshold in thresholds] if len(condlist) == len(choices) - 1: @@ -47,25 +46,27 @@ def apply_thresholds( # must be true to return it. condlist += [True] - assert len(condlist) == len(choices), \ - " ".join([ - "apply_thresholds must be called with the same number of", - "thresholds than choices, or one more choice", - ]) + msg = ( + "'apply_thresholds' must be called with the same number of thresholds " + "than choices, or one more choice." + ) + assert len(condlist) == len(choices), msg return numpy.select(condlist, choices) -def concat(this: ArrayLike[str], that: ArrayLike[str]) -> ArrayType[str]: - """Concatenates the values of two arrays. +def concat( + this: t.Array[numpy.str_] | t.ArrayLike[object], + that: t.Array[numpy.str_] | t.ArrayLike[object], +) -> t.Array[numpy.str_]: + """Concatenate the values of two arrays. Args: this: An array to concatenate. that: Another array to concatenate. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - An array with the concatenated values. + Array[numpy.str_]: An array with the concatenated values. Examples: >>> this = ["this", "that"] @@ -75,36 +76,39 @@ def concat(this: ArrayLike[str], that: ArrayLike[str]) -> ArrayType[str]: """ - if isinstance(this, numpy.ndarray) and \ - not numpy.issubdtype(this.dtype, numpy.str_): - this = this.astype('str') + if not isinstance(this, numpy.ndarray): + this = numpy.array(this) + + if not numpy.issubdtype(this.dtype, numpy.str_): + this = this.astype("str") - if isinstance(that, numpy.ndarray) and \ - not numpy.issubdtype(that.dtype, numpy.str_): - that = that.astype('str') + if not isinstance(that, numpy.ndarray): + that = numpy.array(that) + + if not numpy.issubdtype(that.dtype, numpy.str_): + that = that.astype("str") return numpy.char.add(this, that) def switch( - conditions: ArrayType[float], - value_by_condition: Dict[float, Any], - ) -> ArrayType[float]: - """Reproduces a switch statement. + conditions: t.Array[numpy.float32] | t.ArrayLike[float], + value_by_condition: Mapping[float, float], +) -> t.Array[numpy.float32]: + """Mimick a switch statement. Given an array of conditions, returns an array of the same size, - replacing each condition item by the corresponding given value. + replacing each condition item with the matching given value. Args: conditions: An array of conditions. value_by_condition: Values to replace for each condition. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - An array with the replaced values. + Array: An array with the replaced values. Raises: - :exc:`AssertionError`: When ``value_by_condition`` is empty. + AssertionError: When ``value_by_condition`` is empty. Examples: >>> conditions = numpy.array([1, 1, 1, 2]) @@ -113,13 +117,13 @@ def switch( array([80, 80, 80, 90]) """ + assert ( + len(value_by_condition) > 0 + ), "'switch' must be called with at least one value." + + condlist = [conditions == condition for condition in value_by_condition] - assert len(value_by_condition) > 0, \ - "switch must be called with at least one value" + return numpy.select(condlist, tuple(value_by_condition.values())) - condlist = [ - conditions == condition - for condition in value_by_condition.keys() - ] - return numpy.select(condlist, value_by_condition.values()) +__all__ = ["apply_thresholds", "concat", "switch"] diff --git a/openfisca_core/commons/misc.py b/openfisca_core/commons/misc.py index 5056389c6e..ba9687619c 100644 --- a/openfisca_core/commons/misc.py +++ b/openfisca_core/commons/misc.py @@ -1,12 +1,13 @@ -from typing import TypeVar +from __future__ import annotations -from openfisca_core.types import ArrayType +import numexpr +import numpy -T = TypeVar("T") +from openfisca_core import types as t -def empty_clone(original: T) -> T: - """Creates an empty instance of the same class of the original object. +def empty_clone(original: object) -> object: + """Create an empty instance of the same class of the original object. Args: original: An object to clone. @@ -29,37 +30,35 @@ def empty_clone(original: T) -> T: """ - Dummy: object - new: T + def __init__(_: object) -> None: ... Dummy = type( "Dummy", (original.__class__,), - {"__init__": lambda self: None}, - ) + {"__init__": __init__}, + ) new = Dummy() new.__class__ = original.__class__ return new -def stringify_array(array: ArrayType) -> str: - """Generates a clean string representation of a numpy array. +def stringify_array(array: None | t.Array[numpy.generic]) -> str: + """Generate a clean string representation of a numpy array. Args: array: An array. Returns: - :obj:`str`: - "None" if the ``array`` is None, the stringified ``array`` otherwise. + str: "None" if the ``array`` is None. + str: The stringified ``array`` otherwise. Examples: >>> import numpy - >>> stringify_array(None) 'None' - >>> array = numpy.array([10, 20.]) + >>> array = numpy.array([10, 20.0]) >>> stringify_array(array) '[10.0, 20.0]' @@ -77,3 +76,34 @@ def stringify_array(array: ArrayType) -> str: return "None" return f"[{', '.join(str(cell) for cell in array)}]" + + +def eval_expression( + expression: str, +) -> str | t.Array[numpy.bool_] | t.Array[numpy.int32] | t.Array[numpy.float32]: + """Evaluate a string expression to a numpy array. + + Args: + expression: An expression to evaluate. + + Returns: + Array: The result of the evaluation. + str: The expression if it couldn't be evaluated. + + Examples: + >>> eval_expression("1 + 2") + array(3, dtype=int32) + + >>> eval_expression("salary") + 'salary' + + """ + + try: + return numexpr.evaluate(expression) + + except (KeyError, TypeError): + return expression + + +__all__ = ["empty_clone", "eval_expression", "stringify_array"] diff --git a/tests/web_api/case_with_extension/__init__.py b/openfisca_core/commons/py.typed similarity index 100% rename from tests/web_api/case_with_extension/__init__.py rename to openfisca_core/commons/py.typed diff --git a/openfisca_core/commons/rates.py b/openfisca_core/commons/rates.py index 487abd3aea..cefc65406e 100644 --- a/openfisca_core/commons/rates.py +++ b/openfisca_core/commons/rates.py @@ -1,19 +1,19 @@ -from typing import Optional +from __future__ import annotations import numpy -from openfisca_core.types import ArrayLike, ArrayType +from . import types as t def average_rate( - target: ArrayType[float], - varying: ArrayLike[float], - trim: Optional[ArrayLike[float]] = None, - ) -> ArrayType[float]: - """Computes the average rate of a target net income. + target: t.Array[numpy.float32], + varying: t.Array[numpy.float32] | t.ArrayLike[float], + trim: None | t.ArrayLike[float] = None, +) -> t.Array[numpy.float32]: + """Compute the average rate of a target net income. Given a ``target`` net income, and according to the ``varying`` gross - income. Optionally, a ``trim`` can be applied consisting on the lower and + income. Optionally, a ``trim`` can be applied consisting of the lower and upper bounds of the average rate to be computed. Note: @@ -25,49 +25,46 @@ def average_rate( trim: The lower and upper bounds of the average rate. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - - The average rate for each target. - - When ``trim`` is provided, values that are out of the provided bounds - are replaced by :obj:`numpy.nan`. + Array[numpy.float32]: The average rate for each target. When ``trim`` + is provided, values that are out of the provided bounds are + replaced by :any:`numpy.nan`. Examples: >>> target = numpy.array([1, 2, 3]) >>> varying = [2, 2, 2] - >>> trim = [-1, .25] + >>> trim = [-1, 0.25] >>> average_rate(target, varying, trim) array([ nan, 0. , -0.5]) """ - average_rate: ArrayType[float] + if not isinstance(varying, numpy.ndarray): + varying = numpy.array(varying, dtype=numpy.float32) average_rate = 1 - target / varying if trim is not None: - average_rate = numpy.where( average_rate <= max(trim), average_rate, numpy.nan, - ) + ) average_rate = numpy.where( average_rate >= min(trim), average_rate, numpy.nan, - ) + ) return average_rate def marginal_rate( - target: ArrayType[float], - varying: ArrayType[float], - trim: Optional[ArrayLike[float]] = None, - ) -> ArrayType[float]: - """Computes the marginal rate of a target net income. + target: t.Array[numpy.float32], + varying: t.Array[numpy.float32] | t.ArrayLike[float], + trim: None | t.ArrayLike[float] = None, +) -> t.Array[numpy.float32]: + """Compute the marginal rate of a target net income. Given a ``target`` net income, and according to the ``varying`` gross income. Optionally, a ``trim`` can be applied consisting of the lower and @@ -82,43 +79,38 @@ def marginal_rate( trim: The lower and upper bounds of the marginal rate. Returns: - :obj:`numpy.ndarray` of :obj:`float`: - - The marginal rate for each target. - - When ``trim`` is provided, values that are out of the provided bounds - are replaced by :obj:`numpy.nan`. + Array[numpy.float32]: The marginal rate for each target. When ``trim`` + is provided, values that are out of the provided bounds are replaced by + :any:`numpy.nan`. Examples: >>> target = numpy.array([1, 2, 3]) >>> varying = numpy.array([1, 2, 4]) - >>> trim = [.25, .75] + >>> trim = [0.25, 0.75] >>> marginal_rate(target, varying, trim) array([nan, 0.5]) """ - marginal_rate: ArrayType[float] + if not isinstance(varying, numpy.ndarray): + varying = numpy.array(varying, dtype=numpy.float32) - marginal_rate = ( - + 1 - - (target[:-1] - - target[1:]) / (varying[:-1] - - varying[1:]) - ) + marginal_rate = +1 - (target[:-1] - target[1:]) / (varying[:-1] - varying[1:]) if trim is not None: - marginal_rate = numpy.where( marginal_rate <= max(trim), marginal_rate, numpy.nan, - ) + ) marginal_rate = numpy.where( marginal_rate >= min(trim), marginal_rate, numpy.nan, - ) + ) return marginal_rate + + +__all__ = ["average_rate", "marginal_rate"] diff --git a/openfisca_core/commons/tests/test_decorators.py b/openfisca_core/commons/tests/test_decorators.py deleted file mode 100644 index 04c5ce3d91..0000000000 --- a/openfisca_core/commons/tests/test_decorators.py +++ /dev/null @@ -1,20 +0,0 @@ -import re - -import pytest - -from openfisca_core.commons import deprecated - - -def test_deprecated(): - """The decorated function throws a deprecation warning when used.""" - - since = "yesterday" - expires = "doomsday" - match = re.compile(f"^.*{since}.*{expires}.*$") - - @deprecated(since, expires) - def function(a: int, b: float) -> float: - return a + b - - with pytest.warns(DeprecationWarning, match = match): - assert function(1, 2.) == 3. diff --git a/openfisca_core/commons/tests/test_dummy.py b/openfisca_core/commons/tests/test_dummy.py index d4ecec3842..dfe04b3e44 100644 --- a/openfisca_core/commons/tests/test_dummy.py +++ b/openfisca_core/commons/tests/test_dummy.py @@ -3,7 +3,7 @@ from openfisca_core.commons import Dummy -def test_dummy_deprecation(): +def test_dummy_deprecation() -> None: """Dummy throws a deprecation warning when instantiated.""" with pytest.warns(DeprecationWarning): diff --git a/openfisca_core/commons/tests/test_formulas.py b/openfisca_core/commons/tests/test_formulas.py index f05725cb80..130df9505b 100644 --- a/openfisca_core/commons/tests/test_formulas.py +++ b/openfisca_core/commons/tests/test_formulas.py @@ -5,8 +5,8 @@ from openfisca_core import commons -def test_apply_thresholds_when_several_inputs(): - """Makes a choice for any given input.""" +def test_apply_thresholds_when_several_inputs() -> None: + """Make a choice for any given input.""" input_ = numpy.array([4, 5, 6, 7, 8, 9, 10]) thresholds = [5, 7, 9] @@ -17,8 +17,8 @@ def test_apply_thresholds_when_several_inputs(): assert_array_equal(result, [10, 10, 15, 15, 20, 20, 25]) -def test_apply_thresholds_when_too_many_thresholds(): - """Raises an AssertionError when thresholds > choices.""" +def test_apply_thresholds_when_too_many_thresholds() -> None: + """Raise an AssertionError when thresholds > choices.""" input_ = numpy.array([6]) thresholds = [5, 7, 9, 11] @@ -28,8 +28,8 @@ def test_apply_thresholds_when_too_many_thresholds(): assert commons.apply_thresholds(input_, thresholds, choices) -def test_apply_thresholds_when_too_many_choices(): - """Raises an AssertionError when thresholds < choices - 1.""" +def test_apply_thresholds_when_too_many_choices() -> None: + """Raise an AssertionError when thresholds < choices - 1.""" input_ = numpy.array([6]) thresholds = [5, 7] @@ -39,8 +39,8 @@ def test_apply_thresholds_when_too_many_choices(): assert commons.apply_thresholds(input_, thresholds, choices) -def test_concat_when_this_is_array_not_str(): - """Casts ``this`` to ``str`` when it is a numpy array other than string.""" +def test_concat_when_this_is_array_not_str() -> None: + """Cast ``this`` to ``str`` when it is a NumPy array other than string.""" this = numpy.array([1, 2]) that = numpy.array(["la", "o"]) @@ -50,8 +50,8 @@ def test_concat_when_this_is_array_not_str(): assert_array_equal(result, ["1la", "2o"]) -def test_concat_when_that_is_array_not_str(): - """Casts ``that`` to ``str`` when it is a numpy array other than string.""" +def test_concat_when_that_is_array_not_str() -> None: + """Cast ``that`` to ``str`` when it is a NumPy array other than string.""" this = numpy.array(["ho", "cha"]) that = numpy.array([1, 2]) @@ -61,18 +61,19 @@ def test_concat_when_that_is_array_not_str(): assert_array_equal(result, ["ho1", "cha2"]) -def test_concat_when_args_not_str_array_like(): - """Raises a TypeError when args are not a string array-like object.""" +def test_concat_when_args_not_str_array_like() -> None: + """Cast ``this`` and ``that`` to a NumPy array or strings.""" this = (1, 2) that = (3, 4) - with pytest.raises(TypeError): - commons.concat(this, that) + result = commons.concat(this, that) + + assert_array_equal(result, ["13", "24"]) -def test_switch_when_values_are_empty(): - """Raises an AssertionError when the values are empty.""" +def test_switch_when_values_are_empty() -> None: + """Raise an AssertionError when the values are empty.""" conditions = [1, 1, 1, 2] value_by_condition = {} diff --git a/openfisca_core/commons/tests/test_rates.py b/openfisca_core/commons/tests/test_rates.py index e603a05241..c266582fc5 100644 --- a/openfisca_core/commons/tests/test_rates.py +++ b/openfisca_core/commons/tests/test_rates.py @@ -1,26 +1,28 @@ +import math + import numpy from numpy.testing import assert_array_equal from openfisca_core import commons -def test_average_rate_when_varying_is_zero(): - """Yields infinity when the varying gross income crosses zero.""" +def test_average_rate_when_varying_is_zero() -> None: + """Yield infinity when the varying gross income crosses zero.""" target = numpy.array([1, 2, 3]) varying = [0, 0, 0] result = commons.average_rate(target, varying) - assert_array_equal(result, [- numpy.inf, - numpy.inf, - numpy.inf]) + assert_array_equal(result, numpy.array([-math.inf, -math.inf, -math.inf])) -def test_marginal_rate_when_varying_is_zero(): - """Yields infinity when the varying gross income crosses zero.""" +def test_marginal_rate_when_varying_is_zero() -> None: + """Yield infinity when the varying gross income crosses zero.""" target = numpy.array([1, 2, 3]) varying = numpy.array([0, 0, 0]) result = commons.marginal_rate(target, varying) - assert_array_equal(result, [numpy.inf, numpy.inf]) + assert_array_equal(result, numpy.array([math.inf, math.inf])) diff --git a/openfisca_core/commons/types.py b/openfisca_core/commons/types.py new file mode 100644 index 0000000000..39c067f455 --- /dev/null +++ b/openfisca_core/commons/types.py @@ -0,0 +1,3 @@ +from openfisca_core.types import Array, ArrayLike + +__all__ = ["Array", "ArrayLike"] diff --git a/openfisca_core/data_storage/in_memory_storage.py b/openfisca_core/data_storage/in_memory_storage.py index bd40460a56..0808612ba8 100644 --- a/openfisca_core/data_storage/in_memory_storage.py +++ b/openfisca_core/data_storage/in_memory_storage.py @@ -1,20 +1,19 @@ import numpy from openfisca_core import periods +from openfisca_core.periods import DateUnit class InMemoryStorage: - """ - Low-level class responsible for storing and retrieving calculated vectors in memory - """ + """Low-level class responsible for storing and retrieving calculated vectors in memory.""" - def __init__(self, is_eternal = False): + def __init__(self, is_eternal=False) -> None: self._arrays = {} self.is_eternal = is_eternal def get(self, period): if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) values = self._arrays.get(period) @@ -22,43 +21,43 @@ def get(self, period): return None return values - def put(self, value, period): + def put(self, value, period) -> None: if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) self._arrays[period] = value - def delete(self, period = None): + def delete(self, period=None) -> None: if period is None: self._arrays = {} return if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) self._arrays = { period_item: value for period_item, value in self._arrays.items() if not period.contains(period_item) - } + } def get_known_periods(self): return self._arrays.keys() def get_memory_usage(self): if not self._arrays: - return dict( - nb_arrays = 0, - total_nb_bytes = 0, - cell_size = numpy.nan, - ) + return { + "nb_arrays": 0, + "total_nb_bytes": 0, + "cell_size": numpy.nan, + } nb_arrays = len(self._arrays) array = next(iter(self._arrays.values())) - return dict( - nb_arrays = nb_arrays, - total_nb_bytes = array.nbytes * nb_arrays, - cell_size = array.itemsize, - ) + return { + "nb_arrays": nb_arrays, + "total_nb_bytes": array.nbytes * nb_arrays, + "cell_size": array.itemsize, + } diff --git a/openfisca_core/data_storage/on_disk_storage.py b/openfisca_core/data_storage/on_disk_storage.py index 10d4696b58..9133db2376 100644 --- a/openfisca_core/data_storage/on_disk_storage.py +++ b/openfisca_core/data_storage/on_disk_storage.py @@ -5,14 +5,15 @@ from openfisca_core import periods from openfisca_core.indexed_enums import EnumArray +from openfisca_core.periods import DateUnit class OnDiskStorage: - """ - Low-level class responsible for storing and retrieving calculated vectors on disk - """ + """Low-level class responsible for storing and retrieving calculated vectors on disk.""" - def __init__(self, storage_dir, is_eternal = False, preserve_storage_dir = False): + def __init__( + self, storage_dir, is_eternal=False, preserve_storage_dir=False + ) -> None: self._files = {} self._enums = {} self.is_eternal = is_eternal @@ -23,12 +24,11 @@ def _decode_file(self, file): enum = self._enums.get(file) if enum is not None: return EnumArray(numpy.load(file), enum) - else: - return numpy.load(file) + return numpy.load(file) def get(self, period): if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) values = self._files.get(period) @@ -36,26 +36,26 @@ def get(self, period): return None return self._decode_file(values) - def put(self, value, period): + def put(self, value, period) -> None: if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) filename = str(period) - path = os.path.join(self.storage_dir, filename) + '.npy' + path = os.path.join(self.storage_dir, filename) + ".npy" if isinstance(value, EnumArray): self._enums[path] = value.possible_values value = value.view(numpy.ndarray) numpy.save(path, value) self._files[period] = path - def delete(self, period = None): + def delete(self, period=None) -> None: if period is None: self._files = {} return if self.is_eternal: - period = periods.period(periods.ETERNITY) + period = periods.period(DateUnit.ETERNITY) period = periods.period(period) if period is not None: @@ -63,23 +63,23 @@ def delete(self, period = None): period_item: value for period_item, value in self._files.items() if not period.contains(period_item) - } + } def get_known_periods(self): return self._files.keys() - def restore(self): + def restore(self) -> None: self._files = files = {} # Restore self._files from content of storage_dir. for filename in os.listdir(self.storage_dir): - if not filename.endswith('.npy'): + if not filename.endswith(".npy"): continue path = os.path.join(self.storage_dir, filename) - filename_core = filename.rsplit('.', 1)[0] + filename_core = filename.rsplit(".", 1)[0] period = periods.period(filename_core) files[period] = path - def __del__(self): + def __del__(self) -> None: if self.preserve_storage_dir: return shutil.rmtree(self.storage_dir) # Remove the holder temporary files diff --git a/openfisca_core/entities/__init__.py b/openfisca_core/entities/__init__.py index c388037a38..927aad63d6 100644 --- a/openfisca_core/entities/__init__.py +++ b/openfisca_core/entities/__init__.py @@ -87,8 +87,8 @@ from .entity import Entity # noqa: F401 from .group_entity import GroupEntity # noqa: F401 -from .role import Role # noqa: F401 from .helpers import build_entity, check_role_validity # noqa: F401 +from .role import Role # noqa: F401 __all__ = ["Entity", "GroupEntity", "Role"] __all__ = ["build_entity", "check_role_validity", *__all__] diff --git a/openfisca_core/entities/_core_entity.py b/openfisca_core/entities/_core_entity.py new file mode 100644 index 0000000000..f44353e112 --- /dev/null +++ b/openfisca_core/entities/_core_entity.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from typing import ClassVar + +import abc +import os + +from . import types as t +from .role import Role + + +class _CoreEntity: + """Base class to build entities from. + + Args: + __key: A key to identify the ``_CoreEntity``. + __plural: The ``key`` pluralised. + __label: A summary description. + __doc: A full description. + *__args: Additional arguments. + + """ + + #: A key to identify the ``_CoreEntity``. + key: t.EntityKey + + #: The ``key`` pluralised. + plural: t.EntityPlural + + #: A summary description. + label: str + + #: A full description. + doc: str + + #: Whether the ``_CoreEntity`` is a person or not. + is_person: ClassVar[bool] + + #: A ``TaxBenefitSystem`` instance. + _tax_benefit_system: None | t.TaxBenefitSystem = None + + @abc.abstractmethod + def __init__( + self, + __key: str, + __plural: str, + __label: str, + __doc: str, + *__args: object, + ) -> None: ... + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.key})" + + def set_tax_benefit_system(self, tax_benefit_system: t.TaxBenefitSystem) -> None: + """A ``_CoreEntity`` belongs to a ``TaxBenefitSystem``.""" + self._tax_benefit_system = tax_benefit_system + + def get_variable( + self, + variable_name: t.VariableName, + check_existence: bool = False, + ) -> t.Variable | None: + """Get ``variable_name`` from ``variables``. + + Args: + variable_name: The ``Variable`` to be found. + check_existence: Was the ``Variable`` found? + + Returns: + Variable: When the ``Variable`` exists. + None: When the ``Variable`` doesn't exist. + + Raises: + ValueError: When ``check_existence`` is ``True`` and + the ``Variable`` doesn't exist. + + """ + + if self._tax_benefit_system is None: + msg = "You must set 'tax_benefit_system' before calling this method." + raise ValueError( + msg, + ) + return self._tax_benefit_system.get_variable(variable_name, check_existence) + + def check_variable_defined_for_entity(self, variable_name: t.VariableName) -> None: + """Check if ``variable_name`` is defined for ``self``. + + Args: + variable_name: The ``Variable`` to be found. + + Returns: + Variable: When the ``Variable`` exists. + None: When the :attr:`_tax_benefit_system` is not set. + + Raises: + ValueError: When the ``Variable`` exists but is defined + for another ``Entity``. + + """ + + entity: None | t.CoreEntity = None + variable: None | t.Variable = self.get_variable( + variable_name, + check_existence=True, + ) + + if variable is not None: + entity = variable.entity + + if entity is None: + return + + if entity.key != self.key: + message = ( + f"You tried to compute the variable '{variable_name}' for", + f"the entity '{self.plural}'; however the variable", + f"'{variable_name}' is defined for '{entity.plural}'.", + "Learn more about entities in our documentation:", + ".", + ) + raise ValueError(os.linesep.join(message)) + + @staticmethod + def check_role_validity(role: object) -> None: + """Check if ``role`` is an instance of ``Role``. + + Args: + role: Any object. + + Raises: + ValueError: When ``role`` is not a ``Role``. + + """ + + if role is not None and not isinstance(role, Role): + msg = f"{role} is not a valid role" + raise ValueError(msg) + + +__all__ = ["_CoreEntity"] diff --git a/openfisca_core/entities/_description.py b/openfisca_core/entities/_description.py new file mode 100644 index 0000000000..78634ca270 --- /dev/null +++ b/openfisca_core/entities/_description.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import dataclasses +import textwrap + + +@dataclasses.dataclass(frozen=True) +class _Description: + """A ``Role``'s description. + + Examples: + >>> data = { + ... "key": "parent", + ... "label": "Parents", + ... "plural": "parents", + ... "doc": "\t\t\tThe one/two adults in charge of the household.", + ... } + + >>> description = _Description(**data) + + >>> repr(_Description) + "" + + >>> repr(description) + "_Description(key='parent', plural='parents', label='Parents', ...)" + + >>> str(description) + "_Description(key='parent', plural='parents', label='Parents', ...)" + + >>> {description} + {_Description(key='parent', plural='parents', label='Parents', doc=...} + + >>> description.key + 'parent' + + """ + + #: A key to identify the ``Role``. + key: str + + #: The ``key`` pluralised. + plural: None | str = None + + #: A summary description. + label: None | str = None + + #: A full description, non-indented. + doc: None | str = None + + def __post_init__(self) -> None: + if self.doc is not None: + object.__setattr__(self, "doc", textwrap.dedent(self.doc)) + + +__all__ = ["_Description"] diff --git a/openfisca_core/entities/_role_builder.py b/openfisca_core/entities/_role_builder.py index f1cf5d6bfb..fbed5dbfd9 100644 --- a/openfisca_core/entities/_role_builder.py +++ b/openfisca_core/entities/_role_builder.py @@ -1,12 +1,13 @@ from __future__ import annotations -import dataclasses from typing import Iterable, Optional, Sequence, Type from openfisca_core.types import HasPlural, RoleLike, SupportsRole +import dataclasses + -@dataclasses.dataclass(frozen = True) +@dataclasses.dataclass(frozen=True) class RoleBuilder: """Builds roles & sub-roles from a given input. @@ -97,9 +98,8 @@ def build(self, item: RoleLike) -> SupportsRole: if subroles: role.subroles = [ - self.build(RoleLike({"key": key, "max": 1})) - for key in subroles - ] + self.build(RoleLike({"key": key, "max": 1})) for key in subroles + ] role.max = len(role.subroles) return role diff --git a/openfisca_core/entities/_variable_proxy.py b/openfisca_core/entities/_variable_proxy.py index 8089171992..7381ff90be 100644 --- a/openfisca_core/entities/_variable_proxy.py +++ b/openfisca_core/entities/_variable_proxy.py @@ -1,13 +1,13 @@ from __future__ import annotations -import functools -import os from typing import Any, Optional, Type - from typing_extensions import Protocol from openfisca_core.types import HasPlural, HasVariables, SupportsFormula +import functools +import os + DOC_URL = "https://openfisca.org/doc/coding-the-legislation" E = HasPlural @@ -19,10 +19,10 @@ class Query(Protocol): """A dummy class to "duck-check" :meth:`.TaxBenefitSystem.get_variable`.""" def __call__( - self, - __arg1: str, - __arg2: bool = False, - ) -> Optional["VariableProxy"]: + self, + __arg1: str, + __arg2: bool = False, + ) -> Optional["VariableProxy"]: """See comment above.""" ... @@ -90,7 +90,7 @@ def __get__(self, entity: E, type: Type[E]) -> Optional[VariableProxy]: self.entity, "tax_benefit_system", None, - ) + ) if self.tax_benefit_system is None: return None @@ -132,8 +132,8 @@ def exists(self) -> VariableProxy: self.query = functools.partial( self.query, - check_existence = True, - ) + check_existence=True, + ) return self @@ -146,7 +146,7 @@ def isdefined(self) -> VariableProxy: self.query = functools.partial( self._isdefined, self.query, - ) + ) return self @@ -163,13 +163,15 @@ def _isdefined(self, query: Query, variable_name: str, **any: Any) -> Any: return None if self.entity != variable.entity: - message = os.linesep.join([ - f"You tried to compute the variable '{variable_name}' for", - f"the entity '{self.entity.plural}'; however the variable", - f"'{variable_name}' is defined for the entity", - f"'{variable.entity.plural}'. Learn more about entities", - f"in our documentation: <{DOC_URL}/50_entities.html>.", - ]) + message = os.linesep.join( + [ + f"You tried to compute the variable '{variable_name}' for", + f"the entity '{self.entity.plural}'; however the variable", + f"'{variable_name}' is defined for the entity", + f"'{variable.entity.plural}'. Learn more about entities", + f"in our documentation: <{DOC_URL}/50_entities.html>.", + ] + ) raise ValueError(message) diff --git a/openfisca_core/entities/entity.py b/openfisca_core/entities/entity.py index 82e6673bad..570bfdcca8 100644 --- a/openfisca_core/entities/entity.py +++ b/openfisca_core/entities/entity.py @@ -1,14 +1,11 @@ +from typing import Any, Iterator, Optional, Tuple + +from openfisca_core.types import Descriptor, HasHolders, HasVariables, SupportsFormula + import dataclasses import textwrap -from typing import Any, Iterator, Optional, Tuple from openfisca_core.commons import deprecated -from openfisca_core.types import ( - Descriptor, - HasHolders, - HasVariables, - SupportsFormula, - ) from .. import entities from ._variable_proxy import VariableProxy @@ -126,7 +123,7 @@ class Entity: "is_person", "_population", "_tax_benefit_system", - ] + ] key: str plural: str @@ -160,10 +157,10 @@ def variables(self) -> Optional[VariableProxy]: return self._variables _variables: Descriptor[VariableProxy] = dataclasses.field( - init = False, - compare = False, - default = VariableProxy(), - ) + init=False, + compare=False, + default=VariableProxy(), + ) def __post_init__(self, *__: Any) -> None: self.doc = textwrap.dedent(self.doc) @@ -180,13 +177,13 @@ def __iter__(self) -> Iterator[Tuple[str, Any]]: (item, self.__getattribute__(item)) for item in self.__slots__ if not item.startswith("_") - ) + ) - @deprecated(since = "35.7.0", expires = "the future") + @deprecated(since="35.7.0", expires="the future") def set_tax_benefit_system( - self, - tax_benefit_system: HasVariables, - ) -> None: + self, + tax_benefit_system: HasVariables, + ) -> None: """Sets ``_tax_benefit_system``. Args: @@ -201,12 +198,12 @@ def set_tax_benefit_system( self.tax_benefit_system = tax_benefit_system - @deprecated(since = "35.7.0", expires = "the future") + @deprecated(since="35.7.0", expires="the future") def get_variable( - self, - variable_name: str, - check_existence: bool = False, - ) -> Optional[SupportsFormula]: + self, + variable_name: str, + check_existence: bool = False, + ) -> Optional[SupportsFormula]: """Gets ``variable_name`` from ``variables``. Args: @@ -237,11 +234,11 @@ def get_variable( return self.variables.get(variable_name) - @deprecated(since = "35.7.0", expires = "the future") + @deprecated(since="35.7.0", expires="the future") def check_variable_defined_for_entity( - self, - variable_name: str, - ) -> Optional[SupportsFormula]: + self, + variable_name: str, + ) -> Optional[SupportsFormula]: """Checks if ``variable_name`` is defined for ``self``. Args: @@ -270,7 +267,7 @@ def check_variable_defined_for_entity( return self.variables.isdefined().get(variable_name) @staticmethod - @deprecated(since = "35.7.0", expires = "the future") + @deprecated(since="35.7.0", expires="the future") def check_role_validity(role: Any) -> None: """Checks if ``role`` is an instance of :class:`.Role`. diff --git a/openfisca_core/entities/group_entity.py b/openfisca_core/entities/group_entity.py index 12ab29ac44..066273d879 100644 --- a/openfisca_core/entities/group_entity.py +++ b/openfisca_core/entities/group_entity.py @@ -1,15 +1,16 @@ -import dataclasses -import textwrap from typing import Any, Dict, Iterable, Optional, Sequence from openfisca_core.types import Builder, RoleLike, SupportsRole +import dataclasses +import textwrap + from ._role_builder import RoleBuilder from .entity import Entity from .role import Role -@dataclasses.dataclass(repr = False) +@dataclasses.dataclass(repr=False) class GroupEntity(Entity): """Represents a :class:`.GroupEntity` on which calculations can be run. @@ -104,7 +105,7 @@ class GroupEntity(Entity): "_roles", "_roles_map", "_flattened_roles", - ] + ] __: dataclasses.InitVar[Iterable[RoleLike]] @@ -127,10 +128,8 @@ def flattened_roles(self) -> Sequence[SupportsRole]: @flattened_roles.setter def flattened_roles(self, roles: Sequence[SupportsRole]) -> None: self._flattened_roles = tuple( - array - for role in roles - for array in role.subroles or [role] - ) + array for role in roles for array in role.subroles or [role] + ) def __post_init__(self, *__: Iterable[RoleLike]) -> None: self.doc = textwrap.dedent(self.doc) diff --git a/openfisca_core/entities/helpers.py b/openfisca_core/entities/helpers.py index 71f203aef6..cf566f3df3 100644 --- a/openfisca_core/entities/helpers.py +++ b/openfisca_core/entities/helpers.py @@ -1,20 +1,20 @@ from typing import Any, Iterable, Optional -from openfisca_core.types import RoleLike, HasPlural, SupportsRole +from openfisca_core.types import HasPlural, RoleLike, SupportsRole from .entity import Entity from .group_entity import GroupEntity def build_entity( - key: str, - plural: str, - label: str, - doc: str = "", - roles: Optional[Iterable[RoleLike]] = None, - is_person: bool = False, - class_override: Optional[Any] = None, - ) -> HasPlural: + key: str, + plural: str, + label: str, + doc: str = "", + roles: Optional[Iterable[RoleLike]] = None, + is_person: bool = False, + class_override: Optional[Any] = None, +) -> HasPlural: """Builds an :class:`.Entity` or a :class:`.GroupEntity`. Args: diff --git a/tests/web_api/loader/__init__.py b/openfisca_core/entities/py.typed similarity index 100% rename from tests/web_api/loader/__init__.py rename to openfisca_core/entities/py.typed diff --git a/openfisca_core/entities/role.py b/openfisca_core/entities/role.py index 50ae88c65a..5fdea85bb0 100644 --- a/openfisca_core/entities/role.py +++ b/openfisca_core/entities/role.py @@ -1,18 +1,14 @@ from __future__ import annotations -import dataclasses -import textwrap - from typing import Any, Iterator, Optional, Sequence, Tuple -from openfisca_core.types import ( - HasPlural, - RoleLike, - SupportsRole, - ) +from openfisca_core.types import HasPlural, RoleLike, SupportsRole + +import dataclasses +import textwrap -@dataclasses.dataclass(init = False) +@dataclasses.dataclass(init=False) class Role: """Role of an :class:`.Entity` within a :class:`.GroupEntity`. @@ -79,11 +75,11 @@ class Role: def __init__(self, description: RoleLike, entity: HasPlural) -> None: self.entity = entity - self.key = description['key'] - self.plural = description.get('plural') - self.label = description.get('label') - self.doc = textwrap.dedent(str(description.get('doc', ""))) - self.max = description.get('max') + self.key = description["key"] + self.plural = description.get("plural") + self.label = description.get("label") + self.doc = textwrap.dedent(str(description.get("doc", ""))) + self.max = description.get("max") self.subroles = None def __repr__(self) -> str: diff --git a/openfisca_core/entities/tests/test_helpers.py b/openfisca_core/entities/tests/test_helpers.py index e50f9a8cfc..c6eedb122c 100644 --- a/openfisca_core/entities/tests/test_helpers.py +++ b/openfisca_core/entities/tests/test_helpers.py @@ -7,7 +7,7 @@ def test_build_entity_without_roles(): """Raises a ArgumentError when it's called without roles.""" with pytest.raises(ValueError): - entities.build_entity("", "", "", roles = None) + entities.build_entity("", "", "", roles=None) def test_check_role_validity_when_not_role(): diff --git a/openfisca_core/entities/tests/test_variable_proxy.py b/openfisca_core/entities/tests/test_variable_proxy.py index f190ea169f..f9f571819b 100644 --- a/openfisca_core/entities/tests/test_variable_proxy.py +++ b/openfisca_core/entities/tests/test_variable_proxy.py @@ -33,8 +33,8 @@ def ThisVar(entity): "definition_period": "month", "value_type": float, "entity": entity, - }, - ) + }, + ) @pytest.fixture @@ -62,7 +62,7 @@ def tbs(entity, ThisVar, ThatVar): def test_variables_without_variable_name(variables): """Raises a TypeError when called without ``variable_name``.""" - with pytest.raises(TypeError, match = "'variable_name'"): + with pytest.raises(TypeError, match="'variable_name'"): variables.get() @@ -75,7 +75,7 @@ def test_variables_without_owner(variables): def test_variables_setter(entity): """Raises AttributeError when tryng to set ``variables``.""" - with pytest.raises(AttributeError, match = "can't set attribute"): + with pytest.raises(AttributeError, match="can't set attribute"): entity.variables = object() @@ -113,7 +113,7 @@ def test_variables_when_doesnt_exist_and_check_exists(entity, tbs): entity.tax_benefit_system = tbs - with pytest.raises(VariableNotFoundError, match = "'OtherVar'"): + with pytest.raises(VariableNotFoundError, match="'OtherVar'"): entity.variables.exists().get("OtherVar") @@ -130,7 +130,7 @@ def test_variables_when_exists_and_not_defined_for(entity, tbs): entity.tax_benefit_system = tbs - with pytest.raises(ValueError, match = "'martians'"): + with pytest.raises(ValueError, match="'martians'"): entity.variables.isdefined().get("ThatVar") @@ -139,7 +139,7 @@ def test_variables_when_doesnt_exist_and_check_defined_for(entity, tbs): entity.tax_benefit_system = tbs - with pytest.raises(VariableNotFoundError, match = "'OtherVar'"): + with pytest.raises(VariableNotFoundError, match="'OtherVar'"): entity.variables.isdefined().get("OtherVar") diff --git a/openfisca_core/entities/types.py b/openfisca_core/entities/types.py new file mode 100644 index 0000000000..ef6af9024f --- /dev/null +++ b/openfisca_core/entities/types.py @@ -0,0 +1,42 @@ +from typing_extensions import Required, TypedDict + +from openfisca_core.types import ( + CoreEntity, + EntityKey, + EntityPlural, + GroupEntity, + Role, + RoleKey, + RolePlural, + SingleEntity, + TaxBenefitSystem, + Variable, + VariableName, +) + +# Entities + + +class RoleParams(TypedDict, total=False): + key: Required[str] + plural: str + label: str + doc: str + max: int + subroles: list[str] + + +__all__ = [ + "CoreEntity", + "EntityKey", + "EntityPlural", + "GroupEntity", + "Role", + "RoleKey", + "RoleParams", + "RolePlural", + "SingleEntity", + "TaxBenefitSystem", + "Variable", + "VariableName", +] diff --git a/openfisca_core/errors/__init__.py b/openfisca_core/errors/__init__.py index ccd19af9b2..2c4d438116 100644 --- a/openfisca_core/errors/__init__.py +++ b/openfisca_core/errors/__init__.py @@ -21,13 +21,38 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .cycle_error import CycleError # noqa: F401 -from .empty_argument_error import EmptyArgumentError # noqa: F401 -from .nan_creation_error import NaNCreationError # noqa: F401 -from .parameter_not_found_error import ParameterNotFoundError, ParameterNotFoundError as ParameterNotFound # noqa: F401 -from .parameter_parsing_error import ParameterParsingError # noqa: F401 -from .period_mismatch_error import PeriodMismatchError # noqa: F401 -from .situation_parsing_error import SituationParsingError # noqa: F401 -from .spiral_error import SpiralError # noqa: F401 -from .variable_name_config_error import VariableNameConflictError, VariableNameConflictError as VariableNameConflict # noqa: F401 -from .variable_not_found_error import VariableNotFoundError, VariableNotFoundError as VariableNotFound # noqa: F401 +from .cycle_error import CycleError +from .empty_argument_error import EmptyArgumentError +from .nan_creation_error import NaNCreationError +from .parameter_not_found_error import ( + ParameterNotFoundError, + ParameterNotFoundError as ParameterNotFound, +) +from .parameter_parsing_error import ParameterParsingError +from .period_mismatch_error import PeriodMismatchError +from .situation_parsing_error import SituationParsingError +from .spiral_error import SpiralError +from .variable_name_config_error import ( + VariableNameConflictError, + VariableNameConflictError as VariableNameConflict, +) +from .variable_not_found_error import ( + VariableNotFoundError, + VariableNotFoundError as VariableNotFound, +) + +__all__ = [ + "CycleError", + "EmptyArgumentError", + "NaNCreationError", + "ParameterNotFound", # Deprecated alias for "ParameterNotFoundError + "ParameterNotFoundError", + "ParameterParsingError", + "PeriodMismatchError", + "SituationParsingError", + "SpiralError", + "VariableNameConflict", # Deprecated alias for "VariableNameConflictError" + "VariableNameConflictError", + "VariableNotFound", # Deprecated alias for "VariableNotFoundError" + "VariableNotFoundError", +] diff --git a/openfisca_core/errors/cycle_error.py b/openfisca_core/errors/cycle_error.py index b4d44b5993..b81cc7b3f9 100644 --- a/openfisca_core/errors/cycle_error.py +++ b/openfisca_core/errors/cycle_error.py @@ -1,4 +1,2 @@ class CycleError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/empty_argument_error.py b/openfisca_core/errors/empty_argument_error.py index d3bcddbf9a..960d8d28c2 100644 --- a/openfisca_core/errors/empty_argument_error.py +++ b/openfisca_core/errors/empty_argument_error.py @@ -1,6 +1,7 @@ +import typing + import os import traceback -import typing import numpy @@ -11,12 +12,12 @@ class EmptyArgumentError(IndexError): message: str def __init__( - self, - class_name: str, - method_name: str, - arg_name: str, - arg_value: typing.Union[typing.List, numpy.ndarray] - ) -> None: + self, + class_name: str, + method_name: str, + arg_name: str, + arg_value: typing.Union[list, numpy.ndarray], + ) -> None: message = [ f"'{class_name}.{method_name}' can't be run with an empty '{arg_name}':\n", f">>> {arg_name}", @@ -31,7 +32,7 @@ def __init__( "- Mention us via https://twitter.com/openfisca", "- Drop us a line to contact@openfisca.org\n", "😃", - ] + ] stacktrace = os.linesep.join(traceback.format_stack()) self.message = os.linesep.join([f" {line}" for line in message]) self.message = os.linesep.join([stacktrace, self.message]) diff --git a/openfisca_core/errors/nan_creation_error.py b/openfisca_core/errors/nan_creation_error.py index dfd1b7af7e..373e391517 100644 --- a/openfisca_core/errors/nan_creation_error.py +++ b/openfisca_core/errors/nan_creation_error.py @@ -1,4 +1,2 @@ class NaNCreationError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/parameter_not_found_error.py b/openfisca_core/errors/parameter_not_found_error.py index 624ef490e6..bad33c89f4 100644 --- a/openfisca_core/errors/parameter_not_found_error.py +++ b/openfisca_core/errors/parameter_not_found_error.py @@ -1,21 +1,16 @@ class ParameterNotFoundError(AttributeError): - """ - Exception raised when a parameter is not found in the parameters. - """ + """Exception raised when a parameter is not found in the parameters.""" - def __init__(self, name, instant_str, variable_name = None): - """ - :param name: Name of the parameter + def __init__(self, name, instant_str, variable_name=None) -> None: + """:param name: Name of the parameter :param instant_str: Instant where the parameter does not exist, in the format `YYYY-MM-DD`. :param variable_name: If the parameter was queried during the computation of a variable, name of that variable. """ self.name = name self.instant_str = instant_str self.variable_name = variable_name - message = "The parameter '{}'".format(name) + message = f"The parameter '{name}'" if variable_name is not None: - message += " requested by variable '{}'".format(variable_name) - message += ( - " was not found in the {} tax and benefit system." - ).format(instant_str) - super(ParameterNotFoundError, self).__init__(message) + message += f" requested by variable '{variable_name}'" + message += f" was not found in the {instant_str} tax and benefit system." + super().__init__(message) diff --git a/openfisca_core/errors/parameter_parsing_error.py b/openfisca_core/errors/parameter_parsing_error.py index aa92124290..7628e42d86 100644 --- a/openfisca_core/errors/parameter_parsing_error.py +++ b/openfisca_core/errors/parameter_parsing_error.py @@ -2,21 +2,17 @@ class ParameterParsingError(Exception): - """ - Exception raised when a parameter cannot be parsed. - """ + """Exception raised when a parameter cannot be parsed.""" - def __init__(self, message, file = None, traceback = None): - """ - :param message: Error message + def __init__(self, message, file=None, traceback=None) -> None: + """:param message: Error message :param file: Parameter file which caused the error (optional) :param traceback: Traceback (optional) """ if file is not None: - message = os.linesep.join([ - "Error parsing parameter file '{}':".format(file), - message - ]) + message = os.linesep.join( + [f"Error parsing parameter file '{file}':", message], + ) if traceback is not None: message = os.linesep.join([traceback, message]) - super(ParameterParsingError, self).__init__(message) + super().__init__(message) diff --git a/openfisca_core/errors/period_mismatch_error.py b/openfisca_core/errors/period_mismatch_error.py index 0ba01abcd0..fcece9474d 100644 --- a/openfisca_core/errors/period_mismatch_error.py +++ b/openfisca_core/errors/period_mismatch_error.py @@ -1,9 +1,7 @@ class PeriodMismatchError(ValueError): - """ - Exception raised when one tries to set a variable value for a period that doesn't match its definition period - """ + """Exception raised when one tries to set a variable value for a period that doesn't match its definition period.""" - def __init__(self, variable_name, period, definition_period, message): + def __init__(self, variable_name: str, period, definition_period, message) -> None: self.variable_name = variable_name self.period = period self.definition_period = definition_period diff --git a/openfisca_core/errors/situation_parsing_error.py b/openfisca_core/errors/situation_parsing_error.py index f5c11e65cb..a5d7ee88d3 100644 --- a/openfisca_core/errors/situation_parsing_error.py +++ b/openfisca_core/errors/situation_parsing_error.py @@ -1,20 +1,27 @@ +from __future__ import annotations + +from collections.abc import Iterable + import os -import dpath +import dpath.util class SituationParsingError(Exception): - """ - Exception raised when the situation provided as an input for a simulation cannot be parsed - """ + """Exception raised when the situation provided as an input for a simulation cannot be parsed.""" - def __init__(self, path, message, code = None): + def __init__( + self, + path: Iterable[str], + message: str, + code: int | None = None, + ) -> None: self.error = {} - dpath_path = '/'.join([str(item) for item in path]) - message = str(message).strip(os.linesep).replace(os.linesep, ' ') + dpath_path = "/".join([str(item) for item in path]) + message = str(message).strip(os.linesep).replace(os.linesep, " ") dpath.util.new(self.error, dpath_path, message) self.code = code Exception.__init__(self, str(self.error)) - def __str__(self): + def __str__(self) -> str: return str(self.error) diff --git a/openfisca_core/errors/spiral_error.py b/openfisca_core/errors/spiral_error.py index 0495439b68..ffa7fe2850 100644 --- a/openfisca_core/errors/spiral_error.py +++ b/openfisca_core/errors/spiral_error.py @@ -1,4 +1,2 @@ class SpiralError(Exception): """Simulation error.""" - - pass diff --git a/openfisca_core/errors/variable_name_config_error.py b/openfisca_core/errors/variable_name_config_error.py index 7a87d7f5c8..fec1c45864 100644 --- a/openfisca_core/errors/variable_name_config_error.py +++ b/openfisca_core/errors/variable_name_config_error.py @@ -1,6 +1,2 @@ class VariableNameConflictError(Exception): - """ - Exception raised when two variables with the same name are added to a tax and benefit system. - """ - - pass + """Exception raised when two variables with the same name are added to a tax and benefit system.""" diff --git a/openfisca_core/errors/variable_not_found_error.py b/openfisca_core/errors/variable_not_found_error.py index f84ce06f95..46ece4b13c 100644 --- a/openfisca_core/errors/variable_not_found_error.py +++ b/openfisca_core/errors/variable_not_found_error.py @@ -2,29 +2,28 @@ class VariableNotFoundError(Exception): - """ - Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem. - """ + """Exception raised when a variable has been queried but is not defined in the TaxBenefitSystem.""" - def __init__(self, variable_name, tax_benefit_system): - """ - :param variable_name: Name of the variable that was queried. + def __init__(self, variable_name: str, tax_benefit_system) -> None: + """:param variable_name: Name of the variable that was queried. :param tax_benefit_system: Tax benefits system that does not contain `variable_name` """ country_package_metadata = tax_benefit_system.get_package_metadata() - country_package_name = country_package_metadata['name'] - country_package_version = country_package_metadata['version'] + country_package_name = country_package_metadata["name"] + country_package_version = country_package_metadata["version"] if country_package_version: - country_package_id = '{}@{}'.format(country_package_name, country_package_version) + country_package_id = f"{country_package_name}@{country_package_version}" else: country_package_id = country_package_name - message = os.linesep.join([ - "You tried to calculate or to set a value for variable '{0}', but it was not found in the loaded tax and benefit system ({1}).".format(variable_name, country_package_id), - "Are you sure you spelled '{0}' correctly?".format(variable_name), - "If this code used to work and suddenly does not, this is most probably linked to an update of the tax and benefit system.", - "Look at its changelog to learn about renames and removals and update your code. If it is an official package,", - "it is probably available on .".format(country_package_name) - ]) + message = os.linesep.join( + [ + f"You tried to calculate or to set a value for variable '{variable_name}', but it was not found in the loaded tax and benefit system ({country_package_id}).", + f"Are you sure you spelled '{variable_name}' correctly?", + "If this code used to work and suddenly does not, this is most probably linked to an update of the tax and benefit system.", + "Look at its changelog to learn about renames and removals and update your code. If it is an official package,", + f"it is probably available on .", + ], + ) self.message = message self.variable_name = variable_name Exception.__init__(self, self.message) diff --git a/openfisca_core/experimental/memory_config.py b/openfisca_core/experimental/memory_config.py index 5f3b4a1126..fec38e3a54 100644 --- a/openfisca_core/experimental/memory_config.py +++ b/openfisca_core/experimental/memory_config.py @@ -4,21 +4,25 @@ class MemoryConfig: - - def __init__(self, - max_memory_occupation, - priority_variables = None, - variables_to_drop = None): + def __init__( + self, + max_memory_occupation, + priority_variables=None, + variables_to_drop=None, + ) -> None: message = [ "Memory configuration is a feature that is still currently under experimentation.", "You are very welcome to use it and send us precious feedback,", - "but keep in mind that the way it is used might change without any major version bump." - ] - warnings.warn(" ".join(message), MemoryConfigWarning) + "but keep in mind that the way it is used might change without any major version bump.", + ] + warnings.warn(" ".join(message), MemoryConfigWarning, stacklevel=2) self.max_memory_occupation = float(max_memory_occupation) if self.max_memory_occupation > 1: - raise ValueError("max_memory_occupation must be <= 1") + msg = "max_memory_occupation must be <= 1" + raise ValueError(msg) self.max_memory_occupation_pc = self.max_memory_occupation * 100 - self.priority_variables = set(priority_variables) if priority_variables else set() + self.priority_variables = ( + set(priority_variables) if priority_variables else set() + ) self.variables_to_drop = set(variables_to_drop) if variables_to_drop else set() diff --git a/openfisca_core/holders/__init__.py b/openfisca_core/holders/__init__.py index a7d46e38a6..c8422af7d5 100644 --- a/openfisca_core/holders/__init__.py +++ b/openfisca_core/holders/__init__.py @@ -21,5 +21,9 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .helpers import set_input_dispatch_by_period, set_input_divide_by_period # noqa: F401 +from .helpers import ( # noqa: F401 + set_input_dispatch_by_period, + set_input_divide_by_period, +) from .holder import Holder # noqa: F401 +from .memory_usage import MemoryUsage # noqa: F401 diff --git a/openfisca_core/holders/helpers.py b/openfisca_core/holders/helpers.py index efe16388e0..fcc6563c79 100644 --- a/openfisca_core/holders/helpers.py +++ b/openfisca_core/holders/helpers.py @@ -7,9 +7,8 @@ log = logging.getLogger(__name__) -def set_input_dispatch_by_period(holder, period, array): - """ - This function can be declared as a ``set_input`` attribute of a variable. +def set_input_dispatch_by_period(holder, period, array) -> None: + """This function can be declared as a ``set_input`` attribute of a variable. In this case, the variable will accept inputs on larger periods that its definition period, and the value for the larger period will be applied to all its subperiods. @@ -20,17 +19,19 @@ def set_input_dispatch_by_period(holder, period, array): period_size = period.size period_unit = period.unit - if holder.variable.definition_period == periods.MONTH: - cached_period_unit = periods.MONTH - elif holder.variable.definition_period == periods.YEAR: - cached_period_unit = periods.YEAR - else: - raise ValueError('set_input_dispatch_by_period can be used only for yearly or monthly variables.') + if holder.variable.definition_period not in ( + periods.DateUnit.isoformat + periods.DateUnit.isocalendar + ): + msg = "set_input_dispatch_by_period can't be used for eternal variables." + raise ValueError( + msg, + ) + cached_period_unit = holder.variable.definition_period after_instant = period.start.offset(period_size, period_unit) # Cache the input data, skipping the existing cached months - sub_period = period.start.period(cached_period_unit) + sub_period = periods.Period((cached_period_unit, period.start, 1)) while sub_period.start < after_instant: existing_array = holder.get_array(sub_period) if existing_array is None: @@ -42,9 +43,8 @@ def set_input_dispatch_by_period(holder, period, array): sub_period = sub_period.offset(1) -def set_input_divide_by_period(holder, period, array): - """ - This function can be declared as a ``set_input`` attribute of a variable. +def set_input_divide_by_period(holder, period, array) -> None: + """This function can be declared as a ``set_input`` attribute of a variable. In this case, the variable will accept inputs on larger periods that its definition period, and the value for the larger period will be divided between its subperiods. @@ -55,18 +55,20 @@ def set_input_divide_by_period(holder, period, array): period_size = period.size period_unit = period.unit - if holder.variable.definition_period == periods.MONTH: - cached_period_unit = periods.MONTH - elif holder.variable.definition_period == periods.YEAR: - cached_period_unit = periods.YEAR - else: - raise ValueError('set_input_divide_by_period can be used only for yearly or monthly variables.') + if holder.variable.definition_period not in ( + periods.DateUnit.isoformat + periods.DateUnit.isocalendar + ): + msg = "set_input_divide_by_period can't be used for eternal variables." + raise ValueError( + msg, + ) + cached_period_unit = holder.variable.definition_period after_instant = period.start.offset(period_size, period_unit) # Count the number of elementary periods to change, and the difference with what is already known. remaining_array = array.copy() - sub_period = period.start.period(cached_period_unit) + sub_period = periods.Period((cached_period_unit, period.start, 1)) sub_periods_count = 0 while sub_period.start < after_instant: existing_array = holder.get_array(sub_period) @@ -79,10 +81,13 @@ def set_input_divide_by_period(holder, period, array): # Cache the input data if sub_periods_count > 0: divided_array = remaining_array / sub_periods_count - sub_period = period.start.period(cached_period_unit) + sub_period = periods.Period((cached_period_unit, period.start, 1)) while sub_period.start < after_instant: if holder.get_array(sub_period) is None: holder._set(sub_period, divided_array) sub_period = sub_period.offset(1) elif not (remaining_array == 0).all(): - raise ValueError("Inconsistent input: variable {0} has already been set for all months contained in period {1}, and value {2} provided for {1} doesn't match the total ({3}). This error may also be thrown if you try to call set_input twice for the same variable and period.".format(holder.variable.name, period, array, array - remaining_array)) + msg = f"Inconsistent input: variable {holder.variable.name} has already been set for all months contained in period {period}, and value {array} provided for {period} doesn't match the total ({array - remaining_array}). This error may also be thrown if you try to call set_input twice for the same variable and period." + raise ValueError( + msg, + ) diff --git a/openfisca_core/holders/holder.py b/openfisca_core/holders/holder.py index 3d0379d22d..7183d4a44a 100644 --- a/openfisca_core/holders/holder.py +++ b/openfisca_core/holders/holder.py @@ -1,79 +1,87 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + import os import warnings import numpy import psutil -from openfisca_core import commons, periods, tools -from openfisca_core.errors import PeriodMismatchError -from openfisca_core.data_storage import InMemoryStorage, OnDiskStorage -from openfisca_core.indexed_enums import Enum +from openfisca_core import ( + commons, + data_storage as storage, + errors, + indexed_enums as enums, + periods, + types, +) + +from .memory_usage import MemoryUsage class Holder: - """ - A holder keeps tracks of a variable values after they have been calculated, or set as an input. - """ + """A holder keeps tracks of a variable values after they have been calculated, or set as an input.""" - def __init__(self, variable, population): + def __init__(self, variable, population) -> None: self.population = population self.variable = variable self.simulation = population.simulation - self._memory_storage = InMemoryStorage(is_eternal = (self.variable.definition_period == periods.ETERNITY)) + self._eternal = self.variable.definition_period == periods.DateUnit.ETERNITY + self._memory_storage = storage.InMemoryStorage(is_eternal=self._eternal) # By default, do not activate on-disk storage, or variable dropping self._disk_storage = None self._on_disk_storable = False self._do_not_store = False if self.simulation and self.simulation.memory_config: - if self.variable.name not in self.simulation.memory_config.priority_variables: + if ( + self.variable.name + not in self.simulation.memory_config.priority_variables + ): self._disk_storage = self.create_disk_storage() self._on_disk_storable = True if self.variable.name in self.simulation.memory_config.variables_to_drop: self._do_not_store = True def clone(self, population): - """ - Copy the holder just enough to be able to run a new simulation without modifying the original simulation. - """ + """Copy the holder just enough to be able to run a new simulation without modifying the original simulation.""" new = commons.empty_clone(self) new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ('population', 'formula', 'simulation'): + if key not in ("population", "formula", "simulation"): new_dict[key] = value - new_dict['population'] = population - new_dict['simulation'] = population.simulation + new_dict["population"] = population + new_dict["simulation"] = population.simulation return new - def create_disk_storage(self, directory = None, preserve = False): + def create_disk_storage(self, directory=None, preserve=False): if directory is None: directory = self.simulation.data_storage_dir storage_dir = os.path.join(directory, self.variable.name) if not os.path.isdir(storage_dir): os.mkdir(storage_dir) - return OnDiskStorage( + return storage.OnDiskStorage( storage_dir, - is_eternal = (self.variable.definition_period == periods.ETERNITY), - preserve_storage_dir = preserve - ) + self._eternal, + preserve_storage_dir=preserve, + ) - def delete_arrays(self, period = None): - """ - If ``period`` is ``None``, remove all known values of the variable. + def delete_arrays(self, period=None) -> None: + """If ``period`` is ``None``, remove all known values of the variable. If ``period`` is not ``None``, only remove all values for any period included in period (e.g. if period is "2017", values for "2017-01", "2017-07", etc. would be removed) """ - self._memory_storage.delete(period) if self._disk_storage: self._disk_storage.delete(period) def get_array(self, period): - """ - Get the value of the variable for the given period. + """Get the value of the variable for the given period. If the value is not known, return ``None``. """ @@ -84,92 +92,149 @@ def get_array(self, period): return value if self._disk_storage: return self._disk_storage.get(period) + return None - def get_memory_usage(self): - """ - Get data about the virtual memory usage of the holder. - - :returns: Memory usage data - :rtype: dict - - Example: - - >>> holder.get_memory_usage() - >>> { - >>> 'nb_arrays': 12, # The holder contains the variable values for 12 different periods - >>> 'nb_cells_by_array': 100, # There are 100 entities (e.g. persons) in our simulation - >>> 'cell_size': 8, # Each value takes 8B of memory - >>> 'dtype': dtype('float64') # Each value is a float 64 - >>> 'total_nb_bytes': 10400 # The holder uses 10.4kB of virtual memory - >>> 'nb_requests': 24 # The variable has been computed 24 times - >>> 'nb_requests_by_array': 2 # Each array stored has been on average requested twice - >>> } - """ + def get_memory_usage(self) -> MemoryUsage: + """Get data about the virtual memory usage of the Holder. - usage = dict( - nb_cells_by_array = self.population.count, - dtype = self.variable.dtype, - ) + Returns: + Memory usage data. + + Examples: + >>> from pprint import pprint + + >>> from openfisca_core import ( + ... entities, + ... populations, + ... simulations, + ... taxbenefitsystems, + ... variables, + ... ) + + >>> entity = entities.Entity("", "", "", "") + + >>> class MyVariable(variables.Variable): + ... definition_period = periods.DateUnit.YEAR + ... entity = entity + ... value_type = int + + >>> population = populations.Population(entity) + >>> variable = MyVariable() + >>> holder = Holder(variable, population) + + >>> tbs = taxbenefitsystems.TaxBenefitSystem([entity]) + >>> entities = {entity.key: population} + >>> simulation = simulations.Simulation(tbs, entities) + >>> holder.simulation = simulation + + >>> pprint(holder.get_memory_usage(), indent=3) + { 'cell_size': nan, + 'dtype': , + 'nb_arrays': 0, + 'nb_cells_by_array': 0, + 'total_nb_bytes': 0... + + """ + usage = MemoryUsage( + nb_cells_by_array=self.population.count, + dtype=self.variable.dtype, + ) usage.update(self._memory_storage.get_memory_usage()) if self.simulation.trace: nb_requests = self.simulation.tracer.get_nb_requests(self.variable.name) - usage.update(dict( - nb_requests = nb_requests, - nb_requests_by_array = nb_requests / float(usage['nb_arrays']) if usage['nb_arrays'] > 0 else numpy.nan - )) + usage.update( + { + "nb_requests": nb_requests, + "nb_requests_by_array": ( + nb_requests / float(usage["nb_arrays"]) + if usage["nb_arrays"] > 0 + else numpy.nan + ), + }, + ) return usage def get_known_periods(self): - """ - Get the list of periods the variable value is known for. - """ + """Get the list of periods the variable value is known for.""" + return list(self._memory_storage.get_known_periods()) + list( + self._disk_storage.get_known_periods() if self._disk_storage else [], + ) - return list(self._memory_storage.get_known_periods()) + list(( - self._disk_storage.get_known_periods() if self._disk_storage else [])) + def set_input( + self, + period: types.Period, + array: numpy.ndarray | Sequence[Any], + ) -> numpy.ndarray | None: + """Set a Variable's array of values of a given Period. - def set_input(self, period, array): - """ - Set a variable's value (``array``) for a given period (``period``) + Args: + period: The period at which the value is set. + array: The input value for the variable. - :param array: the input value for the variable - :param period: the period at which the value is setted + Returns: + The set input array. - Example : + Note: + If a ``set_input`` property has been set for the variable, this + method may accept inputs for periods not matching the + ``definition_period`` of the Variable. To read + more about this, check the `documentation`_. - >>> holder.set_input([12, 14], '2018-04') - >>> holder.get_array('2018-04') - >>> [12, 14] + Examples: + >>> from openfisca_core import entities, populations, variables + >>> entity = entities.Entity("", "", "", "") - If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation `_. - """ + >>> class MyVariable(variables.Variable): + ... definition_period = periods.DateUnit.YEAR + ... entity = entity + ... value_type = int + + >>> variable = MyVariable() + >>> population = populations.Population(entity) + >>> population.count = 2 + + >>> holder = Holder(variable, population) + >>> holder.set_input("2018", numpy.array([12.5, 14])) + >>> holder.get_array("2018") + array([12, 14], dtype=int32) + + >>> holder.set_input("2018", [12.5, 14]) + >>> holder.get_array("2018") + array([12, 14], dtype=int32) + + .. _documentation: + https://openfisca.org/doc/coding-the-legislation/35_periods.html#set-input-automatically-process-variable-inputs-defined-for-periods-not-matching-the-definition-period + + """ period = periods.period(period) - if period.unit == periods.ETERNITY and self.variable.definition_period != periods.ETERNITY: - error_message = os.linesep.join([ - 'Unable to set a value for variable {0} for periods.ETERNITY.', - '{0} is only defined for {1}s. Please adapt your input.', - ]).format( - self.variable.name, - self.variable.definition_period - ) - raise PeriodMismatchError( + + if period.unit == periods.DateUnit.ETERNITY and not self._eternal: + error_message = os.linesep.join( + [ + "Unable to set a value for variable {1} for {0}.", + "{1} is only defined for {2}s. Please adapt your input.", + ], + ).format( + periods.DateUnit.ETERNITY.upper(), + self.variable.name, + self.variable.definition_period, + ) + raise errors.PeriodMismatchError( self.variable.name, period, self.variable.definition_period, - error_message - ) + error_message, + ) if self.variable.is_neutralized: - warning_message = "You cannot set a value for the variable {}, as it has been neutralized. The value you provided ({}) will be ignored.".format(self.variable.name, array) - return warnings.warn( - warning_message, - Warning - ) + warning_message = f"You cannot set a value for the variable {self.variable.name}, as it has been neutralized. The value you provided ({array}) will be ignored." + return warnings.warn(warning_message, Warning, stacklevel=2) if self.variable.value_type in (float, int) and isinstance(array, str): - array = tools.eval_expression(array) + array = commons.eval_expression(array) if self.variable.set_input: return self.variable.set_input(self, period, array) return self._set(period, array) @@ -181,66 +246,80 @@ def _to_array(self, value): # 0-dim arrays are casted to scalar when they interact with float. We don't want that. value = value.reshape(1) if len(value) != self.population.count: + msg = f'Unable to set value "{value}" for variable "{self.variable.name}", as its length is {len(value)} while there are {self.population.count} {self.population.entity.plural} in the simulation.' raise ValueError( - 'Unable to set value "{}" for variable "{}", as its length is {} while there are {} {} in the simulation.' - .format(value, self.variable.name, len(value), self.population.count, self.population.entity.plural)) - if self.variable.value_type == Enum: + msg, + ) + if self.variable.value_type == enums.Enum: value = self.variable.possible_values.encode(value) if value.dtype != self.variable.dtype: try: value = value.astype(self.variable.dtype) except ValueError: + msg = f'Unable to set value "{value}" for variable "{self.variable.name}", as the variable dtype "{self.variable.dtype}" does not match the value dtype "{value.dtype}".' raise ValueError( - 'Unable to set value "{}" for variable "{}", as the variable dtype "{}" does not match the value dtype "{}".' - .format(value, self.variable.name, self.variable.dtype, value.dtype)) + msg, + ) return value - def _set(self, period, value): + def _set(self, period, value) -> None: value = self._to_array(value) - if self.variable.definition_period != periods.ETERNITY: + if not self._eternal: if period is None: - raise ValueError('A period must be specified to set values, except for variables with periods.ETERNITY as as period_definition.') - if (self.variable.definition_period != period.unit or period.size > 1): + msg = ( + f"A period must be specified to set values, except for variables with " + f"{periods.DateUnit.ETERNITY.upper()} as as period_definition." + ) + raise ValueError( + msg, + ) + if self.variable.definition_period != period.unit or period.size > 1: name = self.variable.name - period_size_adj = f'{period.unit}' if (period.size == 1) else f'{period.size}-{period.unit}s' - error_message = os.linesep.join([ - f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".', - f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.', - f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.' - ]) - - raise PeriodMismatchError( + period_size_adj = ( + f"{period.unit}" + if (period.size == 1) + else f"{period.size}-{period.unit}s" + ) + error_message = os.linesep.join( + [ + f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".', + f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.', + f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.', + ], + ) + + raise errors.PeriodMismatchError( self.variable.name, period, self.variable.definition_period, - error_message - ) + error_message, + ) should_store_on_disk = ( - self._on_disk_storable and - self._memory_storage.get(period) is None and # If there is already a value in memory, replace it and don't put a new value in the disk storage - psutil.virtual_memory().percent >= self.simulation.memory_config.max_memory_occupation_pc - ) + self._on_disk_storable + and self._memory_storage.get(period) is None + and psutil.virtual_memory().percent # If there is already a value in memory, replace it and don't put a new value in the disk storage + >= self.simulation.memory_config.max_memory_occupation_pc + ) if should_store_on_disk: self._disk_storage.put(value, period) else: self._memory_storage.put(value, period) - def put_in_cache(self, value, period): + def put_in_cache(self, value, period) -> None: if self._do_not_store: return - if (self.simulation.opt_out_cache and - self.simulation.tax_benefit_system.cache_blacklist and - self.variable.name in self.simulation.tax_benefit_system.cache_blacklist): + if ( + self.simulation.opt_out_cache + and self.simulation.tax_benefit_system.cache_blacklist + and self.variable.name in self.simulation.tax_benefit_system.cache_blacklist + ): return self._set(period, value) def default_array(self): - """ - Return a new array of the appropriate length for the entity, filled with the variable default values. - """ - + """Return a new array of the appropriate length for the entity, filled with the variable default values.""" return self.variable.default_array(self.population.count) diff --git a/openfisca_core/holders/memory_usage.py b/openfisca_core/holders/memory_usage.py new file mode 100644 index 0000000000..2d344318ee --- /dev/null +++ b/openfisca_core/holders/memory_usage.py @@ -0,0 +1,26 @@ +from typing_extensions import TypedDict + +import numpy + + +class MemoryUsage(TypedDict, total=False): + """Virtual memory usage of a Holder. + + Attributes: + cell_size: The amount of bytes assigned to each value. + dtype: The :mod:`numpy.dtype` of any, each, and every value. + nb_arrays: The number of periods for which the Holder contains values. + nb_cells_by_array: The number of entities in the current Simulation. + nb_requests: The number of times the Variable has been computed. + nb_requests_by_array: Average times a stored array has been read. + total_nb_bytes: The total number of bytes used by the Holder. + + """ + + cell_size: int + dtype: numpy.dtype + nb_arrays: int + nb_cells_by_array: int + nb_requests: int + nb_requests_by_array: int + total_nb_bytes: int diff --git a/openfisca_core/holders/tests/__init__.py b/openfisca_core/holders/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/holders/tests/test_helpers.py b/openfisca_core/holders/tests/test_helpers.py new file mode 100644 index 0000000000..948f25288f --- /dev/null +++ b/openfisca_core/holders/tests/test_helpers.py @@ -0,0 +1,134 @@ +import pytest + +from openfisca_core import holders, tools +from openfisca_core.entities import Entity +from openfisca_core.holders import Holder +from openfisca_core.periods import DateUnit, Instant, Period +from openfisca_core.populations import Population +from openfisca_core.variables import Variable + + +@pytest.fixture +def people(): + return Entity( + key="person", + plural="people", + label="An individual member of a larger group.", + doc="People have the particularity of not being someone else.", + ) + + +@pytest.fixture +def Income(people): + return type( + "Income", + (Variable,), + {"value_type": float, "entity": people}, + ) + + +@pytest.fixture +def population(people): + population = Population(people) + population.count = 1 + return population + + +@pytest.mark.parametrize( + ("dispatch_unit", "definition_unit", "values", "expected"), + [ + (DateUnit.YEAR, DateUnit.YEAR, [1.0], [3.0]), + (DateUnit.YEAR, DateUnit.MONTH, [1.0], [36.0]), + (DateUnit.YEAR, DateUnit.DAY, [1.0], [1096.0]), + (DateUnit.YEAR, DateUnit.WEEK, [1.0], [157.0]), + (DateUnit.YEAR, DateUnit.WEEKDAY, [1.0], [1096.0]), + (DateUnit.MONTH, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.MONTH, DateUnit.MONTH, [1.0], [3.0]), + (DateUnit.MONTH, DateUnit.DAY, [1.0], [90.0]), + (DateUnit.MONTH, DateUnit.WEEK, [1.0], [13.0]), + (DateUnit.MONTH, DateUnit.WEEKDAY, [1.0], [90.0]), + (DateUnit.DAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.DAY, [1.0], [3.0]), + (DateUnit.DAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEKDAY, [1.0], [3.0]), + (DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.DAY, [1.0], [21.0]), + (DateUnit.WEEK, DateUnit.WEEK, [1.0], [3.0]), + (DateUnit.WEEK, DateUnit.WEEKDAY, [1.0], [21.0]), + (DateUnit.WEEKDAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.DAY, [1.0], [3.0]), + (DateUnit.WEEKDAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEKDAY, [1.0], [3.0]), + ], +) +def test_set_input_dispatch_by_period( + Income, + population, + dispatch_unit, + definition_unit, + values, + expected, +) -> None: + Income.definition_period = definition_unit + income = Income() + holder = Holder(income, population) + instant = Instant((2022, 1, 1)) + dispatch_period = Period((dispatch_unit, instant, 3)) + + holders.set_input_dispatch_by_period(holder, dispatch_period, values) + total = sum(map(holder.get_array, holder.get_known_periods())) + + tools.assert_near(total, expected, absolute_error_margin=0.001) + + +@pytest.mark.parametrize( + ("divide_unit", "definition_unit", "values", "expected"), + [ + (DateUnit.YEAR, DateUnit.YEAR, [3.0], [1.0]), + (DateUnit.YEAR, DateUnit.MONTH, [36.0], [1.0]), + (DateUnit.YEAR, DateUnit.DAY, [1095.0], [1.0]), + (DateUnit.YEAR, DateUnit.WEEK, [157.0], [1.0]), + (DateUnit.YEAR, DateUnit.WEEKDAY, [1095.0], [1.0]), + (DateUnit.MONTH, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.MONTH, DateUnit.MONTH, [3.0], [1.0]), + (DateUnit.MONTH, DateUnit.DAY, [90.0], [1.0]), + (DateUnit.MONTH, DateUnit.WEEK, [13.0], [1.0]), + (DateUnit.MONTH, DateUnit.WEEKDAY, [90.0], [1.0]), + (DateUnit.DAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.DAY, [3.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.DAY, DateUnit.WEEKDAY, [3.0], [1.0]), + (DateUnit.WEEK, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEK, DateUnit.DAY, [21.0], [1.0]), + (DateUnit.WEEK, DateUnit.WEEK, [3.0], [1.0]), + (DateUnit.WEEK, DateUnit.WEEKDAY, [21.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.YEAR, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.MONTH, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.DAY, [3.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEK, [1.0], [1.0]), + (DateUnit.WEEKDAY, DateUnit.WEEKDAY, [3.0], [1.0]), + ], +) +def test_set_input_divide_by_period( + Income, + population, + divide_unit, + definition_unit, + values, + expected, +) -> None: + Income.definition_period = definition_unit + income = Income() + holder = Holder(income, population) + instant = Instant((2022, 1, 1)) + divide_period = Period((divide_unit, instant, 3)) + + holders.set_input_divide_by_period(holder, divide_period, values) + last = holder.get_array(holder.get_known_periods()[-1]) + + tools.assert_near(last, expected, absolute_error_margin=0.001) diff --git a/openfisca_core/indexed_enums/__init__.py b/openfisca_core/indexed_enums/__init__.py index a6218c9fb0..de0dff6528 100644 --- a/openfisca_core/indexed_enums/__init__.py +++ b/openfisca_core/indexed_enums/__init__.py @@ -45,7 +45,7 @@ # Official Public API from .config import ENUM_ARRAY_DTYPE # noqa: F401 -from .enum_array import EnumArray # noqa: F401 from .enum import Enum # noqa: F401 +from .enum_array import EnumArray # noqa: F401 __all__ = ["ENUM_ARRAY_DTYPE", "EnumArray", "Enum"] diff --git a/openfisca_core/indexed_enums/enum.py b/openfisca_core/indexed_enums/enum.py index 8d76f604f5..15e9942854 100644 --- a/openfisca_core/indexed_enums/enum.py +++ b/openfisca_core/indexed_enums/enum.py @@ -1,12 +1,13 @@ from __future__ import annotations -import enum from typing import Any, List, Union -import numpy - from openfisca_core.types import ArrayType, SupportsEncode +import enum + +import numpy + from .. import indexed_enums as enums from .enum_array import EnumArray @@ -16,7 +17,7 @@ ArrayType[bytes], ArrayType[int], ArrayType[str], - ] +] class Enum(enum.Enum): @@ -98,7 +99,7 @@ class Enum(enum.Enum): value: Any def __init__(self, name: str) -> None: - """ Tweaks :class:`~enum.Enum` to add an index to each enum item. + """Tweaks :class:`~enum.Enum` to add an index to each enum item. When the enum item is initialized, ``self._member_names_`` contains the names of the previously initialized items, so its length is the index @@ -243,9 +244,11 @@ def encode(cls, array: A) -> EnumArray: return array - if isinstance(array, numpy.ndarray) and \ - array.size > 0 and \ - isinstance(array.take(0), (Enum, bytes, str)): + if ( + isinstance(array, numpy.ndarray) + and array.size > 0 + and isinstance(array.take(0), (Enum, bytes, str)) + ): if numpy.issubdtype(array.dtype, bytes): @@ -277,9 +280,6 @@ def encode(cls, array: A) -> EnumArray: choices = [item.index for item in cls] - array = \ - numpy \ - .select(conditions, choices) \ - .astype(enums.ENUM_ARRAY_DTYPE) + array = numpy.select(conditions, choices).astype(enums.ENUM_ARRAY_DTYPE) return EnumArray(array, cls) diff --git a/openfisca_core/indexed_enums/enum_array.py b/openfisca_core/indexed_enums/enum_array.py index 4aed222835..14a106636d 100644 --- a/openfisca_core/indexed_enums/enum_array.py +++ b/openfisca_core/indexed_enums/enum_array.py @@ -2,10 +2,10 @@ from typing import Any, NoReturn, Optional, Type, Union -import numpy - from openfisca_core.types import ArrayLike, ArrayType, SupportsEncode +import numpy + class EnumArray(numpy.ndarray): """:class:`numpy.ndarray` subclass representing an array of :class:`.Enum`. @@ -70,10 +70,10 @@ class EnumArray(numpy.ndarray): """ def __new__( - cls, - input_array: ArrayType[int], - possible_values: Optional[Type[SupportsEncode]] = None, - ) -> EnumArray: + cls, + input_array: ArrayType[int], + possible_values: Optional[Type[SupportsEncode]] = None, + ) -> EnumArray: """See comment above.""" obj: EnumArray @@ -184,7 +184,7 @@ def _forbidden_operation(self, other: Any) -> NoReturn: raise TypeError( "Forbidden operation. The only operations allowed on " f"{self.__class__.__name__}s are '==' and '!='.", - ) + ) __add__ = _forbidden_operation __mul__ = _forbidden_operation @@ -218,7 +218,7 @@ def decode(self) -> ArrayLike[SupportsEncode]: return numpy.select( [self == item.index for item in self.possible_values], list(self.possible_values), - ) + ) def decode_to_str(self) -> ArrayType[str]: """Decodes itself to an array of strings. @@ -243,4 +243,4 @@ def decode_to_str(self) -> ArrayType[str]: return numpy.select( [self == item.index for item in self.possible_values], [item.name for item in self.possible_values], - ) + ) diff --git a/openfisca_core/indexed_enums/tests/test_enum_array.py b/openfisca_core/indexed_enums/tests/test_enum_array.py index dab6b29047..5ccc5256fe 100644 --- a/openfisca_core/indexed_enums/tests/test_enum_array.py +++ b/openfisca_core/indexed_enums/tests/test_enum_array.py @@ -1,7 +1,7 @@ import numpy import pytest -from openfisca_core.indexed_enums import EnumArray, Enum +from openfisca_core.indexed_enums import Enum, EnumArray class MyEnum(Enum): @@ -33,5 +33,5 @@ def test_enum_array_ne_operation(enum_array): def test_enum_array_any_other_operation(enum_array): """Only equality and non-equality operations are permitted.""" - with pytest.raises(TypeError, match = "Forbidden operation."): + with pytest.raises(TypeError, match="Forbidden operation."): enum_array * 1 diff --git a/openfisca_core/indexed_enums/types.py b/openfisca_core/indexed_enums/types.py new file mode 100644 index 0000000000..43c38780ff --- /dev/null +++ b/openfisca_core/indexed_enums/types.py @@ -0,0 +1,3 @@ +from openfisca_core.types import Array + +__all__ = ["Array"] diff --git a/openfisca_core/model_api.py b/openfisca_core/model_api.py index 8ccf5c2763..e36e0d5f76 100644 --- a/openfisca_core/model_api.py +++ b/openfisca_core/model_api.py @@ -1,39 +1,63 @@ -from datetime import date # noqa: F401 +from datetime import date -from numpy import ( # noqa: F401 +from numpy import ( logical_not as not_, maximum as max_, minimum as min_, round as round_, select, where, - ) +) -from openfisca_core.commons import apply_thresholds, concat, switch # noqa: F401 - -from openfisca_core.holders import ( # noqa: F401 +from openfisca_core.commons import apply_thresholds, concat, switch +from openfisca_core.holders import ( set_input_dispatch_by_period, set_input_divide_by_period, - ) - -from openfisca_core.indexed_enums import Enum # noqa: F401 - -from openfisca_core.parameters import ( # noqa: F401 - load_parameter_file, - ParameterNode, - Scale, +) +from openfisca_core.indexed_enums import Enum +from openfisca_core.parameters import ( Bracket, Parameter, + ParameterNode, + Scale, ValuesHistory, - ) - -from openfisca_core.periods import DAY, MONTH, YEAR, ETERNITY, period # noqa: F401 -from openfisca_core.populations import ADD, DIVIDE # noqa: F401 -from openfisca_core.reforms import Reform # noqa: F401 - -from openfisca_core.simulations import ( # noqa: F401 - calculate_output_add, - calculate_output_divide, - ) - -from openfisca_core.variables import Variable # noqa: F401 + load_parameter_file, +) +from openfisca_core.periods import DAY, ETERNITY, MONTH, YEAR, period +from openfisca_core.populations import ADD, DIVIDE +from openfisca_core.reforms import Reform +from openfisca_core.simulations import calculate_output_add, calculate_output_divide +from openfisca_core.variables import Variable + +__all__ = [ + "date", + "not_", + "max_", + "min_", + "round_", + "select", + "where", + "apply_thresholds", + "concat", + "switch", + "set_input_dispatch_by_period", + "set_input_divide_by_period", + "Enum", + "Bracket", + "Parameter", + "ParameterNode", + "Scale", + "ValuesHistory", + "load_parameter_file", + "DAY", + "ETERNITY", + "MONTH", + "YEAR", + "period", + "ADD", + "DIVIDE", + "Reform", + "calculate_output_add", + "calculate_output_divide", + "Variable", +] diff --git a/openfisca_core/parameters/__init__.py b/openfisca_core/parameters/__init__.py index 040ae47056..5d742d4611 100644 --- a/openfisca_core/parameters/__init__.py +++ b/openfisca_core/parameters/__init__.py @@ -21,24 +21,52 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import ParameterNotFound, ParameterParsingError # noqa: F401 +from openfisca_core.errors import ParameterNotFound, ParameterParsingError - -from .config import ( # noqa: F401 +from .at_instant_like import AtInstantLike +from .config import ( ALLOWED_PARAM_TYPES, COMMON_KEYS, FILE_EXTENSIONS, date_constructor, dict_no_duplicate_constructor, - ) +) +from .helpers import contains_nan, load_parameter_file +from .parameter import Parameter +from .parameter_at_instant import ParameterAtInstant +from .parameter_node import ParameterNode +from .parameter_node_at_instant import ParameterNodeAtInstant +from .parameter_scale import ParameterScale, ParameterScale as Scale +from .parameter_scale_bracket import ( + ParameterScaleBracket, + ParameterScaleBracket as Bracket, +) +from .values_history import ValuesHistory +from .vectorial_asof_date_parameter_node_at_instant import ( + VectorialAsofDateParameterNodeAtInstant, +) +from .vectorial_parameter_node_at_instant import VectorialParameterNodeAtInstant -from .at_instant_like import AtInstantLike # noqa: F401 -from .helpers import contains_nan, load_parameter_file # noqa: F401 -from .parameter_at_instant import ParameterAtInstant # noqa: F401 -from .parameter_node_at_instant import ParameterNodeAtInstant # noqa: F401 -from .vectorial_parameter_node_at_instant import VectorialParameterNodeAtInstant # noqa: F401 -from .parameter import Parameter # noqa: F401 -from .parameter_node import ParameterNode # noqa: F401 -from .parameter_scale import ParameterScale, ParameterScale as Scale # noqa: F401 -from .parameter_scale_bracket import ParameterScaleBracket, ParameterScaleBracket as Bracket # noqa: F401 -from .values_history import ValuesHistory # noqa: F401 +__all__ = [ + "ParameterNotFound", + "ParameterParsingError", + "AtInstantLike", + "ALLOWED_PARAM_TYPES", + "COMMON_KEYS", + "FILE_EXTENSIONS", + "date_constructor", + "dict_no_duplicate_constructor", + "contains_nan", + "load_parameter_file", + "Parameter", + "ParameterAtInstant", + "ParameterNode", + "ParameterNodeAtInstant", + "ParameterScale", + "Scale", + "ParameterScaleBracket", + "Bracket", + "ValuesHistory", + "VectorialAsofDateParameterNodeAtInstant", + "VectorialParameterNodeAtInstant", +] diff --git a/openfisca_core/parameters/at_instant_like.py b/openfisca_core/parameters/at_instant_like.py index 1a1db34beb..19c28e98c2 100644 --- a/openfisca_core/parameters/at_instant_like.py +++ b/openfisca_core/parameters/at_instant_like.py @@ -4,9 +4,7 @@ class AtInstantLike(abc.ABC): - """ - Base class for various types of parameters implementing the at instant protocol. - """ + """Base class for various types of parameters implementing the at instant protocol.""" def __call__(self, instant): return self.get_at_instant(instant) @@ -16,5 +14,4 @@ def get_at_instant(self, instant): return self._get_at_instant(instant) @abc.abstractmethod - def _get_at_instant(self, instant): - ... + def _get_at_instant(self, instant): ... diff --git a/openfisca_core/parameters/config.py b/openfisca_core/parameters/config.py index e9a3041ae8..5fb1198bea 100644 --- a/openfisca_core/parameters/config.py +++ b/openfisca_core/parameters/config.py @@ -1,9 +1,9 @@ -import warnings import os +import warnings + import yaml -import typing -from openfisca_core.warnings import LibYAMLWarning +from openfisca_core.warnings import LibYAMLWarning try: from yaml import CLoader as Loader @@ -12,33 +12,44 @@ "libyaml is not installed in your environment.", "This can make OpenFisca slower to start.", "Once you have installed libyaml, run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'", - "so that it is used in your Python environment." + os.linesep - ] - warnings.warn(" ".join(message), LibYAMLWarning) - from yaml import Loader # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270) + "so that it is used in your Python environment." + os.linesep, + ] + warnings.warn(" ".join(message), LibYAMLWarning, stacklevel=2) + from yaml import ( # type: ignore # (see https://github.com/python/mypy/issues/1153#issuecomment-455802270) + Loader, + ) # 'unit' and 'reference' are only listed here for backward compatibility. # It is now recommended to include them in metadata, until a common consensus emerges. -ALLOWED_PARAM_TYPES = (float, int, bool, type(None), typing.List) -COMMON_KEYS = {'description', 'metadata', 'unit', 'reference', 'documentation'} -FILE_EXTENSIONS = {'.yaml', '.yml'} +ALLOWED_PARAM_TYPES = (float, int, bool, type(None), list) +COMMON_KEYS = {"description", "metadata", "unit", "reference", "documentation"} +FILE_EXTENSIONS = {".yaml", ".yml"} def date_constructor(_loader, node): return node.value -yaml.add_constructor('tag:yaml.org,2002:timestamp', date_constructor, Loader = Loader) +yaml.add_constructor("tag:yaml.org,2002:timestamp", date_constructor, Loader=Loader) -def dict_no_duplicate_constructor(loader, node, deep = False): +def dict_no_duplicate_constructor(loader, node, deep=False): keys = [key.value for key, value in node.value] if len(keys) != len(set(keys)): - duplicate = next((key for key in keys if keys.count(key) > 1)) - raise yaml.parser.ParserError('', node.start_mark, f"Found duplicate key '{duplicate}'") + duplicate = next(key for key in keys if keys.count(key) > 1) + msg = "" + raise yaml.parser.ParserError( + msg, + node.start_mark, + f"Found duplicate key '{duplicate}'", + ) return loader.construct_mapping(node, deep) -yaml.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, dict_no_duplicate_constructor, Loader = Loader) +yaml.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, + dict_no_duplicate_constructor, + Loader=Loader, +) diff --git a/openfisca_core/parameters/helpers.py b/openfisca_core/parameters/helpers.py index 75d5a18b73..09925bbcdb 100644 --- a/openfisca_core/parameters/helpers.py +++ b/openfisca_core/parameters/helpers.py @@ -9,91 +9,98 @@ def contains_nan(vector): - if numpy.issubdtype(vector.dtype, numpy.record): - return any([contains_nan(vector[name]) for name in vector.dtype.names]) - else: - return numpy.isnan(vector).any() + if numpy.issubdtype(vector.dtype, numpy.record) or numpy.issubdtype( + vector.dtype, + numpy.void, + ): + return any(contains_nan(vector[name]) for name in vector.dtype.names) + return numpy.isnan(vector).any() -def load_parameter_file(file_path, name = ''): - """ - Load parameters from a YAML file (or a directory containing YAML files). +def load_parameter_file(file_path, name=""): + """Load parameters from a YAML file (or a directory containing YAML files). :returns: An instance of :class:`.ParameterNode` or :class:`.ParameterScale` or :class:`.Parameter`. """ if not os.path.exists(file_path): - raise ValueError("{} does not exist".format(file_path)) + msg = f"{file_path} does not exist" + raise ValueError(msg) if os.path.isdir(file_path): - return parameters.ParameterNode(name, directory_path = file_path) + return parameters.ParameterNode(name, directory_path=file_path) data = _load_yaml_file(file_path) return _parse_child(name, data, file_path) -def _compose_name(path, child_name = None, item_name = None): +def _compose_name(path, child_name=None, item_name=None): if not path: return child_name if child_name is not None: - return '{}.{}'.format(path, child_name) + return f"{path}.{child_name}" if item_name is not None: - return '{}[{}]'.format(path, item_name) + return f"{path}[{item_name}]" + return None def _load_yaml_file(file_path): - with open(file_path, 'r') as f: + with open(file_path) as f: try: - return config.yaml.load(f, Loader = config.Loader) + return config.yaml.load(f, Loader=config.Loader) except (config.yaml.scanner.ScannerError, config.yaml.parser.ParserError): stack_trace = traceback.format_exc() + msg = "Invalid YAML. Check the traceback above for more details." raise ParameterParsingError( - "Invalid YAML. Check the traceback above for more details.", + msg, file_path, - stack_trace - ) + stack_trace, + ) except Exception: stack_trace = traceback.format_exc() + msg = "Invalid parameter file content. Check the traceback above for more details." raise ParameterParsingError( - "Invalid parameter file content. Check the traceback above for more details.", + msg, file_path, - stack_trace - ) + stack_trace, + ) def _parse_child(child_name, child, child_path): - if 'values' in child: + if "values" in child: return parameters.Parameter(child_name, child, child_path) - elif 'brackets' in child: + if "brackets" in child: return parameters.ParameterScale(child_name, child, child_path) - elif isinstance(child, dict) and all([periods.INSTANT_PATTERN.match(str(key)) for key in child.keys()]): + if isinstance(child, dict) and all( + periods.INSTANT_PATTERN.match(str(key)) for key in child + ): return parameters.Parameter(child_name, child, child_path) - else: - return parameters.ParameterNode(child_name, data = child, file_path = child_path) + return parameters.ParameterNode(child_name, data=child, file_path=child_path) -def _set_backward_compatibility_metadata(parameter, data): - if data.get('unit') is not None: - parameter.metadata['unit'] = data['unit'] - if data.get('reference') is not None: - parameter.metadata['reference'] = data['reference'] +def _set_backward_compatibility_metadata(parameter, data) -> None: + if data.get("unit") is not None: + parameter.metadata["unit"] = data["unit"] + if data.get("reference") is not None: + parameter.metadata["reference"] = data["reference"] -def _validate_parameter(parameter, data, data_type = None, allowed_keys = None): +def _validate_parameter(parameter, data, data_type=None, allowed_keys=None) -> None: type_map = { - dict: 'object', - list: 'array', - } + dict: "object", + list: "array", + } if data_type is not None and not isinstance(data, data_type): + msg = f"'{parameter.name}' must be of type {type_map[data_type]}." raise ParameterParsingError( - "'{}' must be of type {}.".format(parameter.name, type_map[data_type]), - parameter.file_path - ) + msg, + parameter.file_path, + ) if allowed_keys is not None: keys = data.keys() for key in keys: if key not in allowed_keys: + msg = f"Unexpected property '{key}' in '{parameter.name}'. Allowed properties are {list(allowed_keys)}." raise ParameterParsingError( - "Unexpected property '{}' in '{}'. Allowed properties are {}." - .format(key, parameter.name, list(allowed_keys)), - parameter.file_path - ) + msg, + parameter.file_path, + ) diff --git a/openfisca_core/parameters/parameter.py b/openfisca_core/parameters/parameter.py index 62fd3f6766..528f54cccd 100644 --- a/openfisca_core/parameters/parameter.py +++ b/openfisca_core/parameters/parameter.py @@ -1,22 +1,30 @@ +from __future__ import annotations + import copy import os -import typing from openfisca_core import commons, periods from openfisca_core.errors import ParameterParsingError -from openfisca_core.parameters import config, helpers, AtInstantLike, ParameterAtInstant + +from . import config, helpers +from .at_instant_like import AtInstantLike +from .parameter_at_instant import ParameterAtInstant class Parameter(AtInstantLike): - """ - A parameter of the legislation. Parameters can change over time. + """A parameter of the legislation. + + Parameters can change over time. - :param string name: Name of the parameter, e.g. "taxes.some_tax.some_param" - :param dict data: Data loaded from a YAML file. - :param string file_path: File the parameter was loaded from. - :param string documentation: Documentation describing parameter usage and context. + Attributes: + values_list: List of the values, in reverse chronological order. + Args: + name: Name of the parameter, e.g. "taxes.some_tax.some_param". + data: Data loaded from a YAML file. + file_path: File the parameter was loaded from. + Instantiate a parameter without metadata: >>> Parameter('rate', data = { @@ -34,63 +42,84 @@ class Parameter(AtInstantLike): } }) - .. attribute:: values_list - - List of the values, in reverse chronological order """ - def __init__(self, name, data, file_path = None): + def __init__(self, name: str, data: dict, file_path: str | None = None) -> None: self.name: str = name - self.file_path: str = file_path - helpers._validate_parameter(self, data, data_type = dict) - self.description: str = None - self.metadata: typing.Dict = {} - self.documentation: str = None + self.file_path: str | None = file_path + helpers._validate_parameter(self, data, data_type=dict) + self.description: str | None = None + self.metadata: dict = {} + self.documentation: str | None = None self.values_history = self # Only for backward compatibility # Normal parameter declaration: the values are declared under the 'values' key: parse the description and metadata. - if data.get('values'): + if data.get("values"): # 'unit' and 'reference' are only listed here for backward compatibility - helpers._validate_parameter(self, data, allowed_keys = config.COMMON_KEYS.union({'values'})) - self.description = data.get('description') + helpers._validate_parameter( + self, + data, + allowed_keys=config.COMMON_KEYS.union({"values"}), + ) + self.description = data.get("description") helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) - helpers._validate_parameter(self, data['values'], data_type = dict) - values = data['values'] + helpers._validate_parameter(self, data["values"], data_type=dict) + values = data["values"] - self.documentation = data.get('documentation') + self.documentation = data.get("documentation") else: # Simplified parameter declaration: only values are provided values = data - instants = sorted(values.keys(), reverse = True) # sort in reverse chronological order + instants = sorted( + values.keys(), + reverse=True, + ) # sort in reverse chronological order values_list = [] for instant_str in instants: if not periods.INSTANT_PATTERN.match(instant_str): + msg = f"Invalid property '{instant_str}' in '{self.name}'. Properties must be valid YYYY-MM-DD instants, such as 2017-01-15." raise ParameterParsingError( - "Invalid property '{}' in '{}'. Properties must be valid YYYY-MM-DD instants, such as 2017-01-15." - .format(instant_str, self.name), - file_path) + msg, + file_path, + ) instant_info = values[instant_str] # Ignore expected values, as they are just metadata - if instant_info == "expected" or isinstance(instant_info, dict) and instant_info.get("expected"): + if ( + instant_info == "expected" + or isinstance(instant_info, dict) + and instant_info.get("expected") + ): continue - value_name = helpers._compose_name(name, item_name = instant_str) - value_at_instant = ParameterAtInstant(value_name, instant_str, data = instant_info, file_path = self.file_path, metadata = self.metadata) + value_name = helpers._compose_name(name, item_name=instant_str) + value_at_instant = ParameterAtInstant( + value_name, + instant_str, + data=instant_info, + file_path=self.file_path, + metadata=self.metadata, + ) values_list.append(value_at_instant) - self.values_list: typing.List[ParameterAtInstant] = values_list + self.values_list: list[ParameterAtInstant] = values_list - def __repr__(self): - return os.linesep.join([ - '{}: {}'.format(value.instant_str, value.value if value.value is not None else 'null') for value in self.values_list - ]) + def __repr__(self) -> str: + return os.linesep.join( + [ + "{}: {}".format( + value.instant_str, + value.value if value.value is not None else "null", + ) + for value in self.values_list + ], + ) def __eq__(self, other): return (self.name == other.name) and (self.values_list == other.values_list) @@ -100,12 +129,13 @@ def clone(self): clone.__dict__ = self.__dict__.copy() clone.metadata = copy.deepcopy(self.metadata) - clone.values_list = [parameter_at_instant.clone() for parameter_at_instant in self.values_list] + clone.values_list = [ + parameter_at_instant.clone() for parameter_at_instant in self.values_list + ] return clone - def update(self, period = None, start = None, stop = None, value = None): - """ - Change the value for a given period. + def update(self, period=None, start=None, stop=None, value=None) -> None: + """Change the value for a given period. :param period: Period where the value is modified. If set, `start` and `stop` should be `None`. :param start: Start of the period. Instance of `openfisca_core.periods.Instant`. If set, `period` should be `None`. @@ -114,15 +144,19 @@ def update(self, period = None, start = None, stop = None, value = None): """ if period is not None: if start is not None or stop is not None: - raise TypeError("Wrong input for 'update' method: use either 'update(period, value = value)' or 'update(start = start, stop = stop, value = value)'. You cannot both use 'period' and 'start' or 'stop'.") + msg = "Wrong input for 'update' method: use either 'update(period, value = value)' or 'update(start = start, stop = stop, value = value)'. You cannot both use 'period' and 'start' or 'stop'." + raise TypeError( + msg, + ) if isinstance(period, str): period = periods.period(period) start = period.start stop = period.stop if start is None: - raise ValueError("You must provide either a start or a period") + msg = "You must provide either a start or a period" + raise ValueError(msg) start_str = str(start) - stop_str = str(stop.offset(1, 'day')) if stop else None + stop_str = str(stop.offset(1, "day")) if stop else None old_values = self.values_list new_values = [] @@ -139,20 +173,27 @@ def update(self, period = None, start = None, stop = None, value = None): if stop_str: if new_values and (stop_str == new_values[-1].instant_str): pass # such interval is empty + elif i < n: + overlapped_value = old_values[i].value + value_name = helpers._compose_name(self.name, item_name=stop_str) + new_interval = ParameterAtInstant( + value_name, + stop_str, + data={"value": overlapped_value}, + ) + new_values.append(new_interval) else: - if i < n: - overlapped_value = old_values[i].value - value_name = helpers._compose_name(self.name, item_name = stop_str) - new_interval = ParameterAtInstant(value_name, stop_str, data = {'value': overlapped_value}) - new_values.append(new_interval) - else: - value_name = helpers._compose_name(self.name, item_name = stop_str) - new_interval = ParameterAtInstant(value_name, stop_str, data = {'value': None}) - new_values.append(new_interval) + value_name = helpers._compose_name(self.name, item_name=stop_str) + new_interval = ParameterAtInstant( + value_name, + stop_str, + data={"value": None}, + ) + new_values.append(new_interval) # Insert new interval - value_name = helpers._compose_name(self.name, item_name = start_str) - new_interval = ParameterAtInstant(value_name, start_str, data = {'value': value}) + value_name = helpers._compose_name(self.name, item_name=start_str) + new_interval = ParameterAtInstant(value_name, start_str, data={"value": value}) new_values.append(new_interval) # Remove covered intervals diff --git a/openfisca_core/parameters/parameter_at_instant.py b/openfisca_core/parameters/parameter_at_instant.py index ea91d25421..ae525cf829 100644 --- a/openfisca_core/parameters/parameter_at_instant.py +++ b/openfisca_core/parameters/parameter_at_instant.py @@ -1,5 +1,4 @@ import copy -import typing from openfisca_core import commons from openfisca_core.errors import ParameterParsingError @@ -7,23 +6,22 @@ class ParameterAtInstant: - """ - A value of a parameter at a given instant. - """ + """A value of a parameter at a given instant.""" # 'unit' and 'reference' are only listed here for backward compatibility - _allowed_keys = set(['value', 'metadata', 'unit', 'reference']) + _allowed_keys = {"value", "metadata", "unit", "reference"} - def __init__(self, name, instant_str, data = None, file_path = None, metadata = None): - """ - :param string name: name of the parameter, e.g. "taxes.some_tax.some_param" - :param string instant_str: Date of the value in the format `YYYY-MM-DD`. + def __init__( + self, name, instant_str, data=None, file_path=None, metadata=None + ) -> None: + """:param str name: name of the parameter, e.g. "taxes.some_tax.some_param" + :param str instant_str: Date of the value in the format `YYYY-MM-DD`. :param dict data: Data, usually loaded from a YAML file. """ self.name: str = name self.instant_str: str = instant_str self.file_path: str = file_path - self.metadata: typing.Dict = {} + self.metadata: dict = {} # Accept { 2015-01-01: 4000 } if not isinstance(data, dict) and isinstance(data, config.ALLOWED_PARAM_TYPES): @@ -31,33 +29,44 @@ def __init__(self, name, instant_str, data = None, file_path = None, metadata = return self.validate(data) - self.value: float = data['value'] + self.value: float = data["value"] if metadata is not None: self.metadata.update(metadata) # Inherit metadata from Parameter helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) - def validate(self, data): - helpers._validate_parameter(self, data, data_type = dict, allowed_keys = self._allowed_keys) + def validate(self, data) -> None: + helpers._validate_parameter( + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, + ) try: - value = data['value'] + value = data["value"] except KeyError: + msg = f"Missing 'value' property for {self.name}" raise ParameterParsingError( - "Missing 'value' property for {}".format(self.name), - self.file_path - ) + msg, + self.file_path, + ) if not isinstance(value, config.ALLOWED_PARAM_TYPES): + msg = f"Value in {self.name} has type {type(value)}, which is not one of the allowed types ({config.ALLOWED_PARAM_TYPES}): {value}" raise ParameterParsingError( - "Value in {} has type {}, which is not one of the allowed types ({}): {}".format(self.name, type(value), config.ALLOWED_PARAM_TYPES, value), - self.file_path - ) + msg, + self.file_path, + ) def __eq__(self, other): - return (self.name == other.name) and (self.instant_str == other.instant_str) and (self.value == other.value) + return ( + (self.name == other.name) + and (self.instant_str == other.instant_str) + and (self.value == other.value) + ) - def __repr__(self): - return "ParameterAtInstant({})".format({self.instant_str: self.value}) + def __repr__(self) -> str: + return "ParameterAtInstant({self.instant_str: self.value})" def clone(self): clone = commons.empty_clone(self) diff --git a/openfisca_core/parameters/parameter_node.py b/openfisca_core/parameters/parameter_node.py index 1dae81dfb4..6f43379b36 100644 --- a/openfisca_core/parameters/parameter_node.py +++ b/openfisca_core/parameters/parameter_node.py @@ -1,28 +1,30 @@ from __future__ import annotations +from collections.abc import Iterable + import copy import os -import typing from openfisca_core import commons, parameters, tools -from . import config, helpers, AtInstantLike, Parameter, ParameterNodeAtInstant + +from . import config, helpers +from .at_instant_like import AtInstantLike +from .parameter import Parameter +from .parameter_node_at_instant import ParameterNodeAtInstant class ParameterNode(AtInstantLike): - """ - A node in the legislation `parameter tree `_. - """ + """A node in the legislation `parameter tree `_.""" - _allowed_keys: typing.Optional[typing.Iterable[str]] = None # By default, no restriction on the keys + _allowed_keys: None | Iterable[str] = None # By default, no restriction on the keys - def __init__(self, name = "", directory_path = None, data = None, file_path = None): - """ - Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`). + def __init__(self, name="", directory_path=None, data=None, file_path=None) -> None: + """Instantiate a ParameterNode either from a dict, (using `data`), or from a directory containing YAML files (using `directory_path`). - :param string name: Name of the node, eg "taxes.some_tax". - :param string directory_path: Directory containing YAML files describing the node. + :param str name: Name of the node, eg "taxes.some_tax". + :param str directory_path: Directory containing YAML files describing the node. :param dict data: Object representing the parameter node. It usually has been extracted from a YAML file. - :param string file_path: YAML file from which the `data` has been extracted from. + :param str file_path: YAML file from which the `data` has been extracted from. Instantiate a ParameterNode from a dict: @@ -44,14 +46,20 @@ def __init__(self, name = "", directory_path = None, data = None, file_path = No Instantiate a ParameterNode from a directory containing YAML parameter files: - >>> node = ParameterNode('benefits', directory_path = '/path/to/country_package/parameters/benefits') + >>> node = ParameterNode( + ... "benefits", + ... directory_path="/path/to/country_package/parameters/benefits", + ... ) """ self.name: str = name - self.children: typing.Dict[str, typing.Union[ParameterNode, Parameter, parameters.ParameterScale]] = {} + self.children: dict[ + str, + ParameterNode | Parameter | parameters.ParameterScale, + ] = {} self.description: str = None self.documentation: str = None self.file_path: str = None - self.metadata: typing.Dict = {} + self.metadata: dict = {} if directory_path: self.file_path = directory_path @@ -64,31 +72,46 @@ def __init__(self, name = "", directory_path = None, data = None, file_path = No if ext not in config.FILE_EXTENSIONS: continue - if child_name == 'index': + if child_name == "index": data = helpers._load_yaml_file(child_path) or {} - helpers._validate_parameter(self, data, allowed_keys = config.COMMON_KEYS) - self.description = data.get('description') - self.documentation = data.get('documentation') + helpers._validate_parameter( + self, + data, + allowed_keys=config.COMMON_KEYS, + ) + self.description = data.get("description") + self.documentation = data.get("documentation") helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) else: child_name_expanded = helpers._compose_name(name, child_name) - child = helpers.load_parameter_file(child_path, child_name_expanded) + child = helpers.load_parameter_file( + child_path, + child_name_expanded, + ) self.add_child(child_name, child) elif os.path.isdir(child_path): child_name = os.path.basename(child_path) child_name_expanded = helpers._compose_name(name, child_name) - child = ParameterNode(child_name_expanded, directory_path = child_path) + child = ParameterNode( + child_name_expanded, + directory_path=child_path, + ) self.add_child(child_name, child) else: self.file_path = file_path - helpers._validate_parameter(self, data, data_type = dict, allowed_keys = self._allowed_keys) - self.description = data.get('description') - self.documentation = data.get('documentation') + helpers._validate_parameter( + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, + ) + self.description = data.get("description") + self.documentation = data.get("documentation") helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) for child_name, child in data.items(): if child_name in config.COMMON_KEYS: continue # do not treat reserved keys as subparameters. @@ -98,41 +121,43 @@ def __init__(self, name = "", directory_path = None, data = None, file_path = No child = helpers._parse_child(child_name_expanded, child, file_path) self.add_child(child_name, child) - def merge(self, other): - """ - Merges another ParameterNode into the current node. + def merge(self, other) -> None: + """Merges another ParameterNode into the current node. In case of child name conflict, the other node child will replace the current node child. """ for child_name, child in other.children.items(): self.add_child(child_name, child) - def add_child(self, name, child): - """ - Add a new child to the node. + def add_child(self, name, child) -> None: + """Add a new child to the node. :param name: Name of the child that must be used to access that child. Should not contain anything that could interfere with the operator `.` (dot). :param child: The new child, an instance of :class:`.ParameterScale` or :class:`.Parameter` or :class:`.ParameterNode`. """ if name in self.children: - raise ValueError("{} has already a child named {}".format(self.name, name)) - if not (isinstance(child, ParameterNode) or isinstance(child, Parameter) or isinstance(child, parameters.ParameterScale)): - raise TypeError("child must be of type ParameterNode, Parameter, or Scale. Instead got {}".format(type(child))) + msg = f"{self.name} has already a child named {name}" + raise ValueError(msg) + if not ( + isinstance(child, (ParameterNode, Parameter, parameters.ParameterScale)) + ): + msg = f"child must be of type ParameterNode, Parameter, or Scale. Instead got {type(child)}" + raise TypeError( + msg, + ) self.children[name] = child setattr(self, name, child) - def __repr__(self): - result = os.linesep.join( - [os.linesep.join( - ["{}:", "{}"]).format(name, tools.indent(repr(value))) - for name, value in sorted(self.children.items())] - ) - return result + def __repr__(self) -> str: + return os.linesep.join( + [ + os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value))) + for name, value in sorted(self.children.items()) + ], + ) def get_descendants(self): - """ - Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode` - """ + """Return a generator containing all the parameters and nodes recursively contained in this `ParameterNode`.""" for child in self.children.values(): yield child yield from child.get_descendants() @@ -142,10 +167,7 @@ def clone(self): clone.__dict__ = self.__dict__.copy() clone.metadata = copy.deepcopy(self.metadata) - clone.children = { - key: child.clone() - for key, child in self.children.items() - } + clone.children = {key: child.clone() for key, child in self.children.items()} for child_key, child in clone.children.items(): setattr(clone, child_key, child) diff --git a/openfisca_core/parameters/parameter_node_at_instant.py b/openfisca_core/parameters/parameter_node_at_instant.py index 49a7704c35..b66c0c1ed7 100644 --- a/openfisca_core/parameters/parameter_node_at_instant.py +++ b/openfisca_core/parameters/parameter_node_at_instant.py @@ -1,5 +1,4 @@ import os -import sys import numpy @@ -9,17 +8,13 @@ class ParameterNodeAtInstant: - """ - Parameter node of the legislation, at a given instant. - """ + """Parameter node of the legislation, at a given instant.""" - def __init__(self, name, node, instant_str): - """ - :param name: Name of the node. + def __init__(self, name, node, instant_str) -> None: + """:param name: Name of the node. :param node: Original :any:`ParameterNode` instance. :param instant_str: A date in the format `YYYY-MM-DD`. """ - # The "technical" attributes are hidden, so that the node children can be easily browsed with auto-completion without pollution self._name = name self._instant_str = instant_str @@ -30,29 +25,35 @@ def __init__(self, name, node, instant_str): if child_at_instant is not None: self.add_child(child_name, child_at_instant) - def add_child(self, child_name, child_at_instant): + def add_child(self, child_name, child_at_instant) -> None: self._children[child_name] = child_at_instant setattr(self, child_name, child_at_instant) def __getattr__(self, key): - param_name = helpers._compose_name(self._name, item_name = key) + param_name = helpers._compose_name(self._name, item_name=key) raise ParameterNotFoundError(param_name, self._instant_str) def __getitem__(self, key): # If fancy indexing is used, cast to a vectorial node if isinstance(key, numpy.ndarray): + # If fancy indexing is used wit a datetime64, cast to a vectorial node supporting datetime64 + if numpy.issubdtype(key.dtype, numpy.datetime64): + return ( + parameters.VectorialAsofDateParameterNodeAtInstant.build_from_node( + self, + )[key] + ) + return parameters.VectorialParameterNodeAtInstant.build_from_node(self)[key] return self._children[key] def __iter__(self): return iter(self._children) - def __repr__(self): - result = os.linesep.join( - [os.linesep.join( - ["{}:", "{}"]).format(name, tools.indent(repr(value))) - for name, value in self._children.items()] - ) - if sys.version_info < (3, 0): - return result - return result + def __repr__(self) -> str: + return os.linesep.join( + [ + os.linesep.join(["{}:", "{}"]).format(name, tools.indent(repr(value))) + for name, value in self._children.items() + ], + ) diff --git a/openfisca_core/parameters/parameter_scale.py b/openfisca_core/parameters/parameter_scale.py index d1cfc26379..b01b6a372a 100644 --- a/openfisca_core/parameters/parameter_scale.py +++ b/openfisca_core/parameters/parameter_scale.py @@ -1,65 +1,72 @@ import copy import os -import typing from openfisca_core import commons, parameters, tools from openfisca_core.errors import ParameterParsingError -from openfisca_core.parameters import config, helpers, AtInstantLike +from openfisca_core.parameters import AtInstantLike, config, helpers from openfisca_core.taxscales import ( LinearAverageRateTaxScale, MarginalAmountTaxScale, MarginalRateTaxScale, SingleAmountTaxScale, - ) +) class ParameterScale(AtInstantLike): - """ - A parameter scale (for instance a marginal scale). - """ + """A parameter scale (for instance a marginal scale).""" # 'unit' and 'reference' are only listed here for backward compatibility - _allowed_keys = config.COMMON_KEYS.union({'brackets'}) + _allowed_keys = config.COMMON_KEYS.union({"brackets"}) - def __init__(self, name, data, file_path): - """ - :param name: name of the scale, eg "taxes.some_scale" + def __init__(self, name, data, file_path) -> None: + """:param name: name of the scale, eg "taxes.some_scale" :param data: Data loaded from a YAML file. In case of a reform, the data can also be created dynamically. :param file_path: File the parameter was loaded from. """ self.name: str = name self.file_path: str = file_path - helpers._validate_parameter(self, data, data_type = dict, allowed_keys = self._allowed_keys) - self.description: str = data.get('description') - self.metadata: typing.Dict = {} + helpers._validate_parameter( + self, + data, + data_type=dict, + allowed_keys=self._allowed_keys, + ) + self.description: str = data.get("description") + self.metadata: dict = {} helpers._set_backward_compatibility_metadata(self, data) - self.metadata.update(data.get('metadata', {})) + self.metadata.update(data.get("metadata", {})) - if not isinstance(data.get('brackets', []), list): + if not isinstance(data.get("brackets", []), list): + msg = f"Property 'brackets' of scale '{self.name}' must be of type array." raise ParameterParsingError( - "Property 'brackets' of scale '{}' must be of type array." - .format(self.name), - self.file_path - ) + msg, + self.file_path, + ) brackets = [] - for i, bracket_data in enumerate(data.get('brackets', [])): - bracket_name = helpers._compose_name(name, item_name = i) - bracket = parameters.ParameterScaleBracket(name = bracket_name, data = bracket_data, file_path = file_path) + for i, bracket_data in enumerate(data.get("brackets", [])): + bracket_name = helpers._compose_name(name, item_name=i) + bracket = parameters.ParameterScaleBracket( + name=bracket_name, + data=bracket_data, + file_path=file_path, + ) brackets.append(bracket) - self.brackets: typing.List[parameters.ParameterScaleBracket] = brackets + self.brackets: list[parameters.ParameterScaleBracket] = brackets def __getitem__(self, key): if isinstance(key, int) and key < len(self.brackets): return self.brackets[key] - else: - raise KeyError(key) + raise KeyError(key) - def __repr__(self): + def __repr__(self) -> str: return os.linesep.join( - ['brackets:'] - + [tools.indent('-' + tools.indent(repr(bracket))[1:]) for bracket in self.brackets] - ) + ["brackets:"] + + [ + tools.indent("-" + tools.indent(repr(bracket))[1:]) + for bracket in self.brackets + ], + ) def get_descendants(self): return iter(()) @@ -76,45 +83,39 @@ def clone(self): def _get_at_instant(self, instant): brackets = [bracket.get_at_instant(instant) for bracket in self.brackets] - if self.metadata.get('type') == 'single_amount': + if self.metadata.get("type") == "single_amount": scale = SingleAmountTaxScale() for bracket in brackets: - if 'amount' in bracket._children and 'threshold' in bracket._children: + if "amount" in bracket._children and "threshold" in bracket._children: amount = bracket.amount threshold = bracket.threshold scale.add_bracket(threshold, amount) return scale - elif any('amount' in bracket._children for bracket in brackets): + if any("amount" in bracket._children for bracket in brackets): scale = MarginalAmountTaxScale() for bracket in brackets: - if 'amount' in bracket._children and 'threshold' in bracket._children: + if "amount" in bracket._children and "threshold" in bracket._children: amount = bracket.amount threshold = bracket.threshold scale.add_bracket(threshold, amount) return scale - elif any('average_rate' in bracket._children for bracket in brackets): + if any("average_rate" in bracket._children for bracket in brackets): scale = LinearAverageRateTaxScale() for bracket in brackets: - if 'base' in bracket._children: - base = bracket.base - else: - base = 1. - if 'average_rate' in bracket._children and 'threshold' in bracket._children: + if ( + "average_rate" in bracket._children + and "threshold" in bracket._children + ): average_rate = bracket.average_rate threshold = bracket.threshold - scale.add_bracket(threshold, average_rate * base) - return scale - else: - scale = MarginalRateTaxScale() - - for bracket in brackets: - if 'base' in bracket._children: - base = bracket.base - else: - base = 1. - if 'rate' in bracket._children and 'threshold' in bracket._children: - rate = bracket.rate - threshold = bracket.threshold - scale.add_bracket(threshold, rate * base) + scale.add_bracket(threshold, average_rate) return scale + scale = MarginalRateTaxScale() + + for bracket in brackets: + if "rate" in bracket._children and "threshold" in bracket._children: + rate = bracket.rate + threshold = bracket.threshold + scale.add_bracket(threshold, rate) + return scale diff --git a/openfisca_core/parameters/parameter_scale_bracket.py b/openfisca_core/parameters/parameter_scale_bracket.py index 6d361d09fa..b9691ea3ca 100644 --- a/openfisca_core/parameters/parameter_scale_bracket.py +++ b/openfisca_core/parameters/parameter_scale_bracket.py @@ -2,8 +2,6 @@ class ParameterScaleBracket(ParameterNode): - """ - A parameter scale bracket. - """ + """A parameter scale bracket.""" - _allowed_keys = set(['amount', 'threshold', 'rate', 'average_rate', 'base']) + _allowed_keys = {"amount", "threshold", "rate", "average_rate"} diff --git a/openfisca_core/parameters/values_history.py b/openfisca_core/parameters/values_history.py index fc55400c89..4c56c72398 100644 --- a/openfisca_core/parameters/values_history.py +++ b/openfisca_core/parameters/values_history.py @@ -1,9 +1,5 @@ -from openfisca_core.parameters import Parameter +from .parameter import Parameter class ValuesHistory(Parameter): - """ - Only for backward compatibility. - """ - - pass + """Only for backward compatibility.""" diff --git a/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py b/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py new file mode 100644 index 0000000000..27be1f6946 --- /dev/null +++ b/openfisca_core/parameters/vectorial_asof_date_parameter_node_at_instant.py @@ -0,0 +1,81 @@ +import numpy + +from openfisca_core.parameters.parameter_node_at_instant import ParameterNodeAtInstant +from openfisca_core.parameters.vectorial_parameter_node_at_instant import ( + VectorialParameterNodeAtInstant, +) + + +class VectorialAsofDateParameterNodeAtInstant(VectorialParameterNodeAtInstant): + """Parameter node of the legislation at a given instant which has been vectorized along some date. + Vectorized parameters allow requests such as parameters.housing_benefit[date], where date is a numpy.datetime64 type vector. + """ + + @staticmethod + def build_from_node(node): + VectorialParameterNodeAtInstant.check_node_vectorisable(node) + subnodes_name = node._children.keys() + # Recursively vectorize the children of the node + vectorial_subnodes = tuple( + [ + ( + VectorialAsofDateParameterNodeAtInstant.build_from_node( + node[subnode_name], + ).vector + if isinstance(node[subnode_name], ParameterNodeAtInstant) + else node[subnode_name] + ) + for subnode_name in subnodes_name + ], + ) + # A vectorial node is a wrapper around a numpy recarray + # We first build the recarray + recarray = numpy.array( + [vectorial_subnodes], + dtype=[ + ( + subnode_name, + subnode.dtype if isinstance(subnode, numpy.recarray) else "float", + ) + for (subnode_name, subnode) in zip(subnodes_name, vectorial_subnodes) + ], + ) + return VectorialAsofDateParameterNodeAtInstant( + node._name, + recarray.view(numpy.recarray), + node._instant_str, + ) + + def __getitem__(self, key): + # If the key is a string, just get the subnode + if isinstance(key, str): + key = numpy.array([key], dtype="datetime64[D]") + return self.__getattr__(key) + # If the key is a vector, e.g. ['1990-11-25', '1983-04-17', '1969-09-09'] + if isinstance(key, numpy.ndarray): + assert numpy.issubdtype(key.dtype, numpy.datetime64) + names = list( + self.dtype.names, + ) # Get all the names of the subnodes, e.g. ['before_X', 'after_X', 'after_Y'] + values = numpy.asarray(list(self.vector[0])) + names = [name for name in names if not name.startswith("before")] + names = [ + numpy.datetime64("-".join(name[len("after_") :].split("_"))) + for name in names + ] + conditions = sum([name <= key for name in names]) + result = values[conditions] + + # If the result is not a leaf, wrap the result in a vectorial node. + if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype( + result.dtype, + numpy.void, + ): + return VectorialAsofDateParameterNodeAtInstant( + self._name, + result.view(numpy.recarray), + self._instant_str, + ) + + return result + return None diff --git a/openfisca_core/parameters/vectorial_parameter_node_at_instant.py b/openfisca_core/parameters/vectorial_parameter_node_at_instant.py index 845b2f9664..74cd02d378 100644 --- a/openfisca_core/parameters/vectorial_parameter_node_at_instant.py +++ b/openfisca_core/parameters/vectorial_parameter_node_at_instant.py @@ -1,3 +1,5 @@ +from typing import NoReturn + import numpy from openfisca_core import parameters @@ -7,91 +9,90 @@ class VectorialParameterNodeAtInstant: - """ - Parameter node of the legislation at a given instant which has been vectorized. - Vectorized parameters allow requests such as parameters.housing_benefit[zipcode], where zipcode is a vector + """Parameter node of the legislation at a given instant which has been vectorized. + Vectorized parameters allow requests such as parameters.housing_benefit[zipcode], where zipcode is a vector. """ @staticmethod def build_from_node(node): VectorialParameterNodeAtInstant.check_node_vectorisable(node) - subnodes_name = node._children.keys() + subnodes_name = sorted(node._children.keys()) # Recursively vectorize the children of the node - vectorial_subnodes = tuple([ - VectorialParameterNodeAtInstant.build_from_node(node[subnode_name]).vector if isinstance(node[subnode_name], parameters.ParameterNodeAtInstant) else node[subnode_name] - for subnode_name in subnodes_name - ]) + vectorial_subnodes = tuple( + [ + ( + VectorialParameterNodeAtInstant.build_from_node( + node[subnode_name], + ).vector + if isinstance(node[subnode_name], parameters.ParameterNodeAtInstant) + else node[subnode_name] + ) + for subnode_name in subnodes_name + ], + ) # A vectorial node is a wrapper around a numpy recarray # We first build the recarray recarray = numpy.array( [vectorial_subnodes], - dtype = [ - (subnode_name, subnode.dtype if isinstance(subnode, numpy.recarray) else 'float') + dtype=[ + ( + subnode_name, + subnode.dtype if isinstance(subnode, numpy.recarray) else "float", + ) for (subnode_name, subnode) in zip(subnodes_name, vectorial_subnodes) - ] - ) + ], + ) - return VectorialParameterNodeAtInstant(node._name, recarray.view(numpy.recarray), node._instant_str) + return VectorialParameterNodeAtInstant( + node._name, + recarray.view(numpy.recarray), + node._instant_str, + ) @staticmethod - def check_node_vectorisable(node): - """ - Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing. - """ + def check_node_vectorisable(node) -> None: + """Check that a node can be casted to a vectorial node, in order to be able to use fancy indexing.""" MESSAGE_PART_1 = "Cannot use fancy indexing on parameter node '{}', as" - MESSAGE_PART_3 = "To use fancy indexing on parameter node, its children must be homogenous." + MESSAGE_PART_3 = ( + "To use fancy indexing on parameter node, its children must be homogenous." + ) MESSAGE_PART_4 = "See more at ." - def raise_key_inhomogeneity_error(node_with_key, node_without_key, missing_key): - message = " ".join([ - MESSAGE_PART_1, - "'{}' exists, but '{}' doesn't.", - MESSAGE_PART_3, - MESSAGE_PART_4, - ]).format( + def raise_key_inhomogeneity_error( + node_with_key, node_without_key, missing_key + ) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' exists, but '{{}}' doesn't. {MESSAGE_PART_3} {MESSAGE_PART_4}".format( node._name, - '.'.join([node_with_key, missing_key]), - '.'.join([node_without_key, missing_key]), - ) + f"{node_with_key}.{missing_key}", + f"{node_without_key}.{missing_key}", + ) raise ValueError(message) - def raise_type_inhomogeneity_error(node_name, non_node_name): - message = " ".join([ - MESSAGE_PART_1, - "'{}' is a node, but '{}' is not.", - MESSAGE_PART_3, - MESSAGE_PART_4, - ]).format( + def raise_type_inhomogeneity_error(node_name, non_node_name) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' is a node, but '{{}}' is not. {MESSAGE_PART_3} {MESSAGE_PART_4}".format( node._name, node_name, non_node_name, - ) + ) raise ValueError(message) - def raise_not_implemented(node_name, node_type): - message = " ".join([ - MESSAGE_PART_1, - "'{}' is a '{}', and fancy indexing has not been implemented yet on this kind of parameters.", - MESSAGE_PART_4, - ]).format( + def raise_not_implemented(node_name, node_type) -> NoReturn: + message = f"{MESSAGE_PART_1} '{{}}' is a '{{}}', and fancy indexing has not been implemented yet on this kind of parameters. {MESSAGE_PART_4}".format( node._name, node_name, node_type, - ) + ) raise NotImplementedError(message) def extract_named_children(node): return { - '.'.join([node._name, key]): value - for key, value in node._children.items() - } - - def check_nodes_homogeneous(named_nodes): - """ - Check than several nodes (or parameters, or baremes) have the same structure. - """ + f"{node._name}.{key}": value for key, value in node._children.items() + } + + def check_nodes_homogeneous(named_nodes) -> None: + """Check than several nodes (or parameters, or baremes) have the same structure.""" names = list(named_nodes.keys()) nodes = list(named_nodes.values()) first_node = nodes[0] @@ -103,18 +104,24 @@ def check_nodes_homogeneous(named_nodes): raise_type_inhomogeneity_error(first_name, name) first_node_keys = first_node._children.keys() node_keys = node._children.keys() - if not first_node_keys == node_keys: + if first_node_keys != node_keys: missing_keys = set(first_node_keys).difference(node_keys) if missing_keys: # If the first_node has a key that node hasn't - raise_key_inhomogeneity_error(first_name, name, missing_keys.pop()) + raise_key_inhomogeneity_error( + first_name, + name, + missing_keys.pop(), + ) else: # If If the node has a key that first_node doesn't have - missing_key = set(node_keys).difference(first_node_keys).pop() + missing_key = ( + set(node_keys).difference(first_node_keys).pop() + ) raise_key_inhomogeneity_error(name, first_name, missing_key) children.update(extract_named_children(node)) check_nodes_homogeneous(children) - elif isinstance(first_node, float) or isinstance(first_node, int): + elif isinstance(first_node, (float, int)): for node, name in list(zip(nodes, names))[1:]: - if isinstance(node, int) or isinstance(node, float): + if isinstance(node, (int, float)): pass elif isinstance(node, parameters.ParameterNodeAtInstant): raise_type_inhomogeneity_error(name, first_name) @@ -126,8 +133,7 @@ def check_nodes_homogeneous(named_nodes): check_nodes_homogeneous(extract_named_children(node)) - def __init__(self, name, vector, instant_str): - + def __init__(self, name, vector, instant_str) -> None: self.vector = vector self._name = name self._instant_str = instant_str @@ -143,28 +149,51 @@ def __getitem__(self, key): if isinstance(key, str): return self.__getattr__(key) # If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1'] - elif isinstance(key, numpy.ndarray): + if isinstance(key, numpy.ndarray): if not numpy.issubdtype(key.dtype, numpy.str_): # In case the key is not a string vector, stringify it if key.dtype == object and issubclass(type(key[0]), Enum): enum = type(key[0]) - key = numpy.select([key == item for item in enum], [item.name for item in enum]) + key = numpy.select( + [key == item for item in enum], + [item.name for item in enum], + ) elif isinstance(key, EnumArray): enum = key.possible_values - key = numpy.select([key == item.index for item in enum], [item.name for item in enum]) + key = numpy.select( + [key == item.index for item in enum], + [item.name for item in enum], + ) else: - key = key.astype('str') - names = list(self.dtype.names) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] - default = numpy.full_like(self.vector[key[0]], numpy.nan) # In case of unexpected key, we will set the corresponding value to NaN. + key = key.astype("str") + names = list( + self.dtype.names, + ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2'] + default = numpy.full_like( + self.vector[key[0]], + numpy.nan, + ) # In case of unexpected key, we will set the corresponding value to NaN. conditions = [key == name for name in names] values = [self.vector[name] for name in names] result = numpy.select(conditions, values, default) if helpers.contains_nan(result): unexpected_key = set(key).difference(self.vector.dtype.names).pop() - raise ParameterNotFoundError('.'.join([self._name, unexpected_key]), self._instant_str) + msg = f"{self._name}.{unexpected_key}" + raise ParameterNotFoundError( + msg, + self._instant_str, + ) # If the result is not a leaf, wrap the result in a vectorial node. - if numpy.issubdtype(result.dtype, numpy.record): - return VectorialParameterNodeAtInstant(self._name, result.view(numpy.recarray), self._instant_str) + if numpy.issubdtype(result.dtype, numpy.record) or numpy.issubdtype( + result.dtype, + numpy.void, + ): + return VectorialParameterNodeAtInstant( + self._name, + result.view(numpy.recarray), + self._instant_str, + ) return result + return None diff --git a/openfisca_core/periods/__init__.py b/openfisca_core/periods/__init__.py index 4cd9db648c..2335f1792a 100644 --- a/openfisca_core/periods/__init__.py +++ b/openfisca_core/periods/__init__.py @@ -21,26 +21,59 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .config import ( # noqa: F401 - DAY, - MONTH, - YEAR, - ETERNITY, +from . import types +from ._errors import InstantError, ParserError, PeriodError +from .config import ( INSTANT_PATTERN, date_by_instant_cache, str_by_instant_cache, year_or_month_or_day_re, - ) - -from .helpers import ( # noqa: F401 - N_, +) +from .date_unit import DateUnit +from .helpers import ( instant, instant_date, - period, key_period_size, - unit_weights, + period, unit_weight, - ) + unit_weights, +) +from .instant_ import Instant +from .period_ import Period + +WEEKDAY = DateUnit.WEEKDAY +WEEK = DateUnit.WEEK +DAY = DateUnit.DAY +MONTH = DateUnit.MONTH +YEAR = DateUnit.YEAR +ETERNITY = DateUnit.ETERNITY +ISOFORMAT = DateUnit.isoformat +ISOCALENDAR = DateUnit.isocalendar -from .instant_ import Instant # noqa: F401 -from .period_ import Period # noqa: F401 +__all__ = [ + "DAY", + "DateUnit", + "ETERNITY", + "INSTANT_PATTERN", + "ISOCALENDAR", + "ISOFORMAT", + "Instant", + "InstantError", + "MONTH", + "ParserError", + "Period", + "PeriodError", + "WEEK", + "WEEKDAY", + "YEAR", + "date_by_instant_cache", + "instant", + "instant_date", + "key_period_size", + "period", + "str_by_instant_cache", + "types", + "unit_weight", + "unit_weights", + "year_or_month_or_day_re", +] diff --git a/openfisca_core/periods/_errors.py b/openfisca_core/periods/_errors.py new file mode 100644 index 0000000000..733d03ce2a --- /dev/null +++ b/openfisca_core/periods/_errors.py @@ -0,0 +1,28 @@ +from pendulum.parsing.exceptions import ParserError + + +class InstantError(ValueError): + """Raised when an invalid instant-like is provided.""" + + def __init__(self, value: str) -> None: + msg = ( + f"'{value}' is not a valid instant string. Instants are described " + "using either the 'YYYY-MM-DD' format, for instance '2015-06-15', " + "or the 'YYYY-Www-D' format, for instance '2015-W24-1'." + ) + super().__init__(msg) + + +class PeriodError(ValueError): + """Raised when an invalid period-like is provided.""" + + def __init__(self, value: str) -> None: + msg = ( + "Expected a period (eg. '2017', 'month:2017-01', 'week:2017-W01-1:3', " + f"...); got: '{value}'. Learn more about legal period formats in " + "OpenFisca: ." + ) + super().__init__(msg) + + +__all__ = ["InstantError", "ParserError", "PeriodError"] diff --git a/openfisca_core/periods/_parsers.py b/openfisca_core/periods/_parsers.py new file mode 100644 index 0000000000..9973b890a0 --- /dev/null +++ b/openfisca_core/periods/_parsers.py @@ -0,0 +1,121 @@ +"""To parse periods and instants from strings.""" + +from __future__ import annotations + +import datetime + +import pendulum + +from . import types as t +from ._errors import InstantError, ParserError, PeriodError +from .date_unit import DateUnit +from .instant_ import Instant +from .period_ import Period + + +def parse_instant(value: str) -> t.Instant: + """Parse a string into an instant. + + Args: + value (str): The string to parse. + + Returns: + An InstantStr. + + Raises: + InstantError: When the string is not a valid ISO Calendar/Format. + ParserError: When the string couldn't be parsed. + + Examples: + >>> parse_instant("2022") + Instant((2022, 1, 1)) + + >>> parse_instant("2022-02") + Instant((2022, 2, 1)) + + >>> parse_instant("2022-W02-7") + Instant((2022, 1, 16)) + + >>> parse_instant("2022-W013") + Traceback (most recent call last): + openfisca_core.periods._errors.InstantError: '2022-W013' is not a va... + + >>> parse_instant("2022-02-29") + Traceback (most recent call last): + pendulum.parsing.exceptions.ParserError: Unable to parse string [202... + + """ + + if not isinstance(value, t.InstantStr): + raise InstantError(str(value)) + + date = pendulum.parse(value, exact=True) + + if not isinstance(date, datetime.date): + msg = f"Unable to parse string [{value}]" + raise ParserError(msg) + + return Instant((date.year, date.month, date.day)) + + +def parse_period(value: str) -> t.Period: + """Parses ISO format/calendar periods. + + Such as "2012" or "2015-03". + + Examples: + >>> parse_period("2022") + Period((, Instant((2022, 1, 1)), 1)) + + >>> parse_period("2022-02") + Period((, Instant((2022, 2, 1)), 1)) + + >>> parse_period("2022-W02-7") + Period((, Instant((2022, 1, 16)), 1)) + + """ + + try: + instant = parse_instant(value) + + except InstantError as error: + raise PeriodError(value) from error + + unit = parse_unit(value) + + return Period((unit, instant, 1)) + + +def parse_unit(value: str) -> t.DateUnit: + """Determine the date unit of a date string. + + Args: + value (str): The date string to parse. + + Returns: + A DateUnit. + + Raises: + InstantError: when no DateUnit can be determined. + + Examples: + >>> parse_unit("2022") + + + >>> parse_unit("2022-W03-1") + + + """ + + if not isinstance(value, t.InstantStr): + raise InstantError(str(value)) + + length = len(value.split("-")) + + if isinstance(value, t.ISOCalendarStr): + return DateUnit.isocalendar[-length] + + return DateUnit.isoformat[-length] + + +__all__ = ["parse_instant", "parse_period", "parse_unit"] diff --git a/openfisca_core/periods/config.py b/openfisca_core/periods/config.py index 6e0c698098..4486a5caf0 100644 --- a/openfisca_core/periods/config.py +++ b/openfisca_core/periods/config.py @@ -1,15 +1,20 @@ import re -import typing -DAY = 'day' -MONTH = 'month' -YEAR = 'year' -ETERNITY = 'eternity' +import pendulum + +from . import types as t # Matches "2015", "2015-01", "2015-01-01" # Does not match "2015-13", "2015-12-32" -INSTANT_PATTERN = re.compile(r"^\d{4}(-(0[1-9]|1[012]))?(-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01]))?$") +INSTANT_PATTERN = re.compile( + r"^\d{4}(-(0[1-9]|1[012]))?(-(0[1-9]|1[012])-(0[1-9]|[12][0-9]|3[01]))?$", +) + +date_by_instant_cache: dict[t.Instant, pendulum.Date] = {} +str_by_instant_cache: dict[t.Instant, t.InstantStr] = {} +year_or_month_or_day_re = re.compile( + r"(18|19|20)\d{2}(-(0?[1-9]|1[0-2])(-([0-2]?\d|3[0-1]))?)?$", +) + -date_by_instant_cache: typing.Dict = {} -str_by_instant_cache: typing.Dict = {} -year_or_month_or_day_re = re.compile(r'(18|19|20)\d{2}(-(0?[1-9]|1[0-2])(-([0-2]?\d|3[0-1]))?)?$') +__all__ = ["INSTANT_PATTERN", "date_by_instant_cache", "str_by_instant_cache"] diff --git a/openfisca_core/periods/date_unit.py b/openfisca_core/periods/date_unit.py new file mode 100644 index 0000000000..c66346c3c2 --- /dev/null +++ b/openfisca_core/periods/date_unit.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from enum import EnumMeta + +from strenum import StrEnum + +from . import types as t + + +class DateUnitMeta(EnumMeta): + @property + def isoformat(self) -> tuple[t.DateUnit, ...]: + """Creates a :obj:`tuple` of ``key`` with isoformat items. + + Returns: + tuple(str): A :obj:`tuple` containing the ``keys``. + + Examples: + >>> DateUnit.isoformat + (, , >> DateUnit.DAY in DateUnit.isoformat + True + + >>> DateUnit.WEEK in DateUnit.isoformat + False + + """ + return DateUnit.DAY, DateUnit.MONTH, DateUnit.YEAR + + @property + def isocalendar(self) -> tuple[t.DateUnit, ...]: + """Creates a :obj:`tuple` of ``key`` with isocalendar items. + + Returns: + tuple(str): A :obj:`tuple` containing the ``keys``. + + Examples: + >>> DateUnit.isocalendar + (, , >> DateUnit.WEEK in DateUnit.isocalendar + True + + >>> "day" in DateUnit.isocalendar + False + + """ + return DateUnit.WEEKDAY, DateUnit.WEEK, DateUnit.YEAR + + +class DateUnit(StrEnum, metaclass=DateUnitMeta): + """The date units of a rule system. + + Examples: + >>> repr(DateUnit) + "" + + >>> repr(DateUnit.DAY) + "" + + >>> str(DateUnit.DAY) + 'day' + + >>> dict([(DateUnit.DAY, DateUnit.DAY.value)]) + {: 'day'} + + >>> list(DateUnit) + [, , >> len(DateUnit) + 6 + + >>> DateUnit["DAY"] + + + >>> DateUnit(DateUnit.DAY) + + + >>> DateUnit.DAY in DateUnit + True + + >>> "day" in list(DateUnit) + True + + >>> DateUnit.DAY == "day" + True + + >>> DateUnit.DAY.name + 'DAY' + + >>> DateUnit.DAY.value + 'day' + + """ + + def __contains__(self, other: object) -> bool: + if isinstance(other, str): + return super().__contains__(other) + return NotImplemented + + WEEKDAY = "weekday" + WEEK = "week" + DAY = "day" + MONTH = "month" + YEAR = "year" + ETERNITY = "eternity" + + +__all__ = ["DateUnit"] diff --git a/openfisca_core/periods/helpers.py b/openfisca_core/periods/helpers.py index 9ddf794d06..fab26c48ab 100644 --- a/openfisca_core/periods/helpers.py +++ b/openfisca_core/periods/helpers.py @@ -1,203 +1,313 @@ +from __future__ import annotations + +from typing import NoReturn + import datetime -import os +import functools + +import pendulum + +from . import config, types as t +from ._errors import InstantError, PeriodError +from ._parsers import parse_instant, parse_period +from .date_unit import DateUnit +from .instant_ import Instant +from .period_ import Period + + +@functools.singledispatch +def instant(value: object) -> t.Instant: + """Build a new instant, aka a triple of integers (year, month, day). + + Args: + value(object): An ``instant-like`` object. + + Returns: + :obj:`.Instant`: A new instant. + + Raises: + :exc:`ValueError`: When the arguments were invalid, like "2021-32-13". + + Examples: + >>> instant((2021,)) + Instant((2021, 1, 1)) + + >>> instant((2021, 9)) + Instant((2021, 9, 1)) + + >>> instant(datetime.date(2021, 9, 16)) + Instant((2021, 9, 16)) -from openfisca_core import periods -from openfisca_core.periods import config + >>> instant(Instant((2021, 9, 16))) + Instant((2021, 9, 16)) + >>> instant(Period((DateUnit.YEAR, Instant((2021, 9, 16)), 1))) + Instant((2021, 9, 16)) -def N_(message): - return message + >>> instant(2021) + Instant((2021, 1, 1)) + >>> instant("2021") + Instant((2021, 1, 1)) -def instant(instant): - """Return a new instant, aka a triple of integers (year, month, day). + >>> instant([2021]) + Instant((2021, 1, 1)) - >>> instant(2014) - Instant((2014, 1, 1)) - >>> instant('2014') - Instant((2014, 1, 1)) - >>> instant('2014-02') - Instant((2014, 2, 1)) - >>> instant('2014-3-2') - Instant((2014, 3, 2)) - >>> instant(instant('2014-3-2')) - Instant((2014, 3, 2)) - >>> instant(period('month', '2014-3-2')) - Instant((2014, 3, 2)) + >>> instant([2021, 9]) + Instant((2021, 9, 1)) + + >>> instant(None) + Traceback (most recent call last): + openfisca_core.periods._errors.InstantError: 'None' is not a valid i... + + """ + + if isinstance(value, t.SeqInt): + return Instant((list(value) + [1] * 3)[:3]) + + raise InstantError(str(value)) + + +@instant.register +def _(value: None) -> NoReturn: + raise InstantError(str(value)) + + +@instant.register +def _(value: int) -> t.Instant: + return Instant((value, 1, 1)) + + +@instant.register +def _(value: Period) -> t.Instant: + return value.start + + +@instant.register +def _(value: t.Instant) -> t.Instant: + return value + + +@instant.register +def _(value: datetime.date) -> t.Instant: + return Instant((value.year, value.month, value.day)) + + +@instant.register +def _(value: str) -> t.Instant: + return parse_instant(value) + + +def instant_date(instant: None | t.Instant) -> None | datetime.date: + """Returns the date representation of an ``Instant``. + + Args: + instant: An ``Instant``. + + Returns: + None: When ``instant`` is None. + datetime.date: Otherwise. + + Examples: + >>> instant_date(Instant((2021, 1, 1))) + Date(2021, 1, 1) - >>> instant(None) """ if instant is None: return None - if isinstance(instant, periods.Instant): - return instant - if isinstance(instant, str): - if not config.INSTANT_PATTERN.match(instant): - raise ValueError("'{}' is not a valid instant. Instants are described using the 'YYYY-MM-DD' format, for instance '2015-06-15'.".format(instant)) - instant = periods.Instant( - int(fragment) - for fragment in instant.split('-', 2)[:3] - ) - elif isinstance(instant, datetime.date): - instant = periods.Instant((instant.year, instant.month, instant.day)) - elif isinstance(instant, int): - instant = (instant,) - elif isinstance(instant, list): - assert 1 <= len(instant) <= 3 - instant = tuple(instant) - elif isinstance(instant, periods.Period): - instant = instant.start - else: - assert isinstance(instant, tuple), instant - assert 1 <= len(instant) <= 3 - if len(instant) == 1: - return periods.Instant((instant[0], 1, 1)) - if len(instant) == 2: - return periods.Instant((instant[0], instant[1], 1)) - return periods.Instant(instant) - - -def instant_date(instant): - if instant is None: - return None + instant_date = config.date_by_instant_cache.get(instant) + if instant_date is None: - config.date_by_instant_cache[instant] = instant_date = datetime.date(*instant) + config.date_by_instant_cache[instant] = instant_date = pendulum.date(*instant) + return instant_date -def period(value): - """Return a new period, aka a triple (unit, start_instant, size). +@functools.singledispatch +def period(value: object) -> t.Period: + """Build a new period, aka a triple (unit, start_instant, size). + + Args: + value: A ``period-like`` object. + + Returns: + :obj:`.Period`: A period. + + Raises: + :exc:`ValueError`: When the arguments were invalid, like "2021-32-13". + + Examples: + >>> period(Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1))) + Period((, Instant((2021, 1, 1)), 1)) + + >>> period(Instant((2021, 1, 1))) + Period((, Instant((2021, 1, 1)), 1)) + + >>> period(DateUnit.ETERNITY) + Period((, Instant((-1, -1, -1)), -1)) - >>> period('2014') - Period((YEAR, Instant((2014, 1, 1)), 1)) - >>> period('year:2014') - Period((YEAR, Instant((2014, 1, 1)), 1)) + >>> period(2021) + Period((, Instant((2021, 1, 1)), 1)) - >>> period('2014-2') - Period((MONTH, Instant((2014, 2, 1)), 1)) - >>> period('2014-02') - Period((MONTH, Instant((2014, 2, 1)), 1)) - >>> period('month:2014-2') - Period((MONTH, Instant((2014, 2, 1)), 1)) + >>> period("2014") + Period((, Instant((2014, 1, 1)), 1)) + + >>> period("year:2014") + Period((, Instant((2014, 1, 1)), 1)) + + >>> period("month:2014-02") + Period((, Instant((2014, 2, 1)), 1)) + + >>> period("year:2014-02") + Period((, Instant((2014, 2, 1)), 1)) + + >>> period("day:2014-02-02") + Period((, Instant((2014, 2, 2)), 1)) + + >>> period("day:2014-02-02:3") + Period((, Instant((2014, 2, 2)), 3)) - >>> period('year:2014-2') - Period((YEAR, Instant((2014, 2, 1)), 1)) """ - if isinstance(value, periods.Period): - return value - - if isinstance(value, periods.Instant): - return periods.Period((config.DAY, value, 1)) - - def parse_simple_period(value): - """ - Parses simple periods respecting the ISO format, such as 2012 or 2015-03 - """ - try: - date = datetime.datetime.strptime(value, '%Y') - except ValueError: + + one, two, three = 1, 2, 3 + + # We return an "eternity-period", for example + # ``, -1))>``. + if str(value).lower() == DateUnit.ETERNITY: + return Period.eternity() + + # We try to parse from an ISO format/calendar period. + if isinstance(value, t.InstantStr): + return parse_period(value) + + # A complex period has a ':' in its string. + if isinstance(value, t.PeriodStr): + components = value.split(":") + + # The left-most component must be a valid unit + unit = components[0] + + if unit not in list(DateUnit) or unit == DateUnit.ETERNITY: + raise PeriodError(str(value)) + + # Cast ``unit`` to DateUnit. + unit = DateUnit(unit) + + # The middle component must be a valid iso period + period = parse_period(components[1]) + + # Periods like year:2015-03 have a size of 1 + if len(components) == two: + size = one + + # if provided, make sure the size is an integer + elif len(components) == three: try: - date = datetime.datetime.strptime(value, '%Y-%m') - except ValueError: - try: - date = datetime.datetime.strptime(value, '%Y-%m-%d') - except ValueError: - return None - else: - return periods.Period((config.DAY, periods.Instant((date.year, date.month, date.day)), 1)) - else: - return periods.Period((config.MONTH, periods.Instant((date.year, date.month, 1)), 1)) + size = int(components[2]) + + except ValueError as error: + raise PeriodError(str(value)) from error + + # If there are more than 2 ":" in the string, the period is invalid else: - return periods.Period((config.YEAR, periods.Instant((date.year, date.month, 1)), 1)) - - def raise_error(value): - message = os.linesep.join([ - "Expected a period (eg. '2017', '2017-01', '2017-01-01', ...); got: '{}'.".format(value), - "Learn more about legal period formats in OpenFisca:", - "." - ]) - raise ValueError(message) - - if value == 'ETERNITY' or value == config.ETERNITY: - return periods.Period(('eternity', instant(datetime.date.min), float("inf"))) - - # check the type - if isinstance(value, int): - return periods.Period((config.YEAR, periods.Instant((value, 1, 1)), 1)) - if not isinstance(value, str): - raise_error(value) - - # try to parse as a simple period - period = parse_simple_period(value) - if period is not None: - return period - - # complex period must have a ':' in their strings - if ":" not in value: - raise_error(value) - - components = value.split(':') - - # left-most component must be a valid unit - unit = components[0] - if unit not in (config.DAY, config.MONTH, config.YEAR): - raise_error(value) - - # middle component must be a valid iso period - base_period = parse_simple_period(components[1]) - if not base_period: - raise_error(value) - - # period like year:2015-03 have a size of 1 - if len(components) == 2: - size = 1 - # if provided, make sure the size is an integer - elif len(components) == 3: - try: - size = int(components[2]) - except ValueError: - raise_error(value) - # if there is more than 2 ":" in the string, the period is invalid - else: - raise_error(value) - - # reject ambiguous period such as month:2014 - if unit_weight(base_period.unit) > unit_weight(unit): - raise_error(value) - - return periods.Period((unit, base_period.start, size)) - - -def key_period_size(period): - """ - Defines a key in order to sort periods by length. It uses two aspects : first unit then size + raise PeriodError(str(value)) + + # Reject ambiguous periods such as month:2014 + if unit_weight(period.unit) > unit_weight(unit): + raise PeriodError(str(value)) + + return Period((unit, period.start, size)) + + raise PeriodError(str(value)) + - :param period: an OpenFisca period - :return: a string +@period.register +def _(value: None) -> NoReturn: + raise PeriodError(str(value)) - >>> key_period_size(period('2014')) - '2_1' - >>> key_period_size(period('2013')) - '2_1' - >>> key_period_size(period('2014-01')) - '1_1' + +@period.register +def _(value: int) -> t.Period: + return Period((DateUnit.YEAR, instant(value), 1)) + + +@period.register +def _(value: t.Period) -> t.Period: + return value + + +@period.register +def _(value: t.Instant) -> t.Period: + return Period((DateUnit.DAY, value, 1)) + + +@period.register +def _(value: datetime.date) -> t.Period: + return Period((DateUnit.DAY, instant(value), 1)) + + +def key_period_size(period: t.Period) -> str: + """Define a key in order to sort periods by length. + + It uses two aspects: first, ``unit``, then, ``size``. + + Args: + period: An :mod:`.openfisca_core` :obj:`.Period`. + + Returns: + :obj:`str`: A string. + + Examples: + >>> instant = Instant((2021, 9, 14)) + + >>> period = Period((DateUnit.DAY, instant, 1)) + >>> key_period_size(period) + '100_1' + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> key_period_size(period) + '300_3' """ - unit, start, size = period + return f"{unit_weight(period.unit)}_{period.size}" + - return '{}_{}'.format(unit_weight(unit), size) +def unit_weights() -> dict[t.DateUnit, int]: + """Assign weights to date units. + Examples: + >>> unit_weights() + {: 100, ...ETERNITY: 'eternity'>: 400} -def unit_weights(): + """ return { - config.DAY: 100, - config.MONTH: 200, - config.YEAR: 300, - config.ETERNITY: 400, - } + DateUnit.WEEKDAY: 100, + DateUnit.WEEK: 200, + DateUnit.DAY: 100, + DateUnit.MONTH: 200, + DateUnit.YEAR: 300, + DateUnit.ETERNITY: 400, + } -def unit_weight(unit): +def unit_weight(unit: t.DateUnit) -> int: + """Retrieves a specific date unit weight. + + Examples: + >>> unit_weight(DateUnit.DAY) + 100 + + """ return unit_weights()[unit] + + +__all__ = [ + "instant", + "instant_date", + "key_period_size", + "period", + "unit_weight", + "unit_weights", +] diff --git a/openfisca_core/periods/instant_.py b/openfisca_core/periods/instant_.py index c3da65f894..f71dbb3222 100644 --- a/openfisca_core/periods/instant_.py +++ b/openfisca_core/periods/instant_.py @@ -1,249 +1,224 @@ -import calendar -import datetime +from __future__ import annotations -from openfisca_core import periods -from openfisca_core.periods import config +import pendulum +from . import config, types as t +from .date_unit import DateUnit -class Instant(tuple): - def __repr__(self): - """ - Transform instant to to its Python representation as a string. - - >>> repr(instant(2014)) - 'Instant((2014, 1, 1))' - >>> repr(instant('2014-2')) - 'Instant((2014, 2, 1))' - >>> repr(instant('2014-2-3')) - 'Instant((2014, 2, 3))' - """ - return '{}({})'.format(self.__class__.__name__, super(Instant, self).__repr__()) +class Instant(tuple[int, int, int]): + """An instant in time (year, month, day). - def __str__(self): - """ - Transform instant to a string. + An :class:`.Instant` represents the most atomic and indivisible + legislation's date unit. - >>> str(instant(2014)) - '2014-01-01' - >>> str(instant('2014-2')) - '2014-02-01' - >>> str(instant('2014-2-3')) - '2014-02-03' + Current implementation considers this unit to be a day, so + :obj:`instants <.Instant>` can be thought of as "day dates". - """ + Examples: + >>> instant = Instant((2021, 9, 13)) + + >>> repr(Instant) + "" + + >>> repr(instant) + 'Instant((2021, 9, 13))' + + >>> str(instant) + '2021-09-13' + + >>> dict([(instant, (2021, 9, 13))]) + {Instant((2021, 9, 13)): (2021, 9, 13)} + + >>> list(instant) + [2021, 9, 13] + + >>> instant[0] + 2021 + + >>> instant[0] in instant + True + + >>> len(instant) + 3 + + >>> instant == (2021, 9, 13) + True + + >>> instant != (2021, 9, 13) + False + + >>> instant > (2020, 9, 13) + True + + >>> instant < (2020, 9, 13) + False + + >>> instant >= (2020, 9, 13) + True + + >>> instant <= (2020, 9, 13) + False + + >>> instant.year + 2021 + + >>> instant.month + 9 + + >>> instant.day + 13 + + >>> instant.date + Date(2021, 9, 13) + + >>> year, month, day = instant + + """ + + __slots__ = () + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + def __str__(self) -> t.InstantStr: instant_str = config.str_by_instant_cache.get(self) + if instant_str is None: - config.str_by_instant_cache[self] = instant_str = self.date.isoformat() + instant_str = t.InstantStr(self.date.isoformat()) + config.str_by_instant_cache[self] = instant_str + return instant_str + def __lt__(self, other: object) -> bool: + if isinstance(other, Instant): + return super().__lt__(other) + return NotImplemented + + def __le__(self, other: object) -> bool: + if isinstance(other, Instant): + return super().__le__(other) + return NotImplemented + @property - def date(self): - """ - Convert instant to a date. - - >>> instant(2014).date - datetime.date(2014, 1, 1) - >>> instant('2014-2').date - datetime.date(2014, 2, 1) - >>> instant('2014-2-3').date - datetime.date(2014, 2, 3) - """ + def date(self) -> pendulum.Date: instant_date = config.date_by_instant_cache.get(self) + if instant_date is None: - config.date_by_instant_cache[self] = instant_date = datetime.date(*self) + instant_date = pendulum.date(*self) + config.date_by_instant_cache[self] = instant_date + return instant_date @property - def day(self): - """ - Extract day from instant. - - >>> instant(2014).day - 1 - >>> instant('2014-2').day - 1 - >>> instant('2014-2-3').day - 3 - """ + def day(self) -> int: return self[2] @property - def month(self): - """ - Extract month from instant. - - >>> instant(2014).month - 1 - >>> instant('2014-2').month - 2 - >>> instant('2014-2-3').month - 2 - """ + def month(self) -> int: return self[1] - def period(self, unit, size = 1): - """ - Create a new period starting at instant. - - >>> instant(2014).period('month') - Period(('month', Instant((2014, 1, 1)), 1)) - >>> instant('2014-2').period('year', 2) - Period(('year', Instant((2014, 2, 1)), 2)) - >>> instant('2014-2-3').period('day', size = 2) - Period(('day', Instant((2014, 2, 3)), 2)) - """ - assert unit in (config.DAY, config.MONTH, config.YEAR), 'Invalid unit: {} of type {}'.format(unit, type(unit)) - assert isinstance(size, int) and size >= 1, 'Invalid size: {} of type {}'.format(size, type(size)) - return periods.Period((unit, self, size)) - - def offset(self, offset, unit): - """ - Increment (or decrement) the given instant with offset units. - - >>> instant(2014).offset(1, 'day') - Instant((2014, 1, 2)) - >>> instant(2014).offset(1, 'month') - Instant((2014, 2, 1)) - >>> instant(2014).offset(1, 'year') - Instant((2015, 1, 1)) - - >>> instant('2014-1-31').offset(1, 'day') - Instant((2014, 2, 1)) - >>> instant('2014-1-31').offset(1, 'month') - Instant((2014, 2, 28)) - >>> instant('2014-1-31').offset(1, 'year') - Instant((2015, 1, 31)) - - >>> instant('2011-2-28').offset(1, 'day') - Instant((2011, 3, 1)) - >>> instant('2011-2-28').offset(1, 'month') - Instant((2011, 3, 28)) - >>> instant('2012-2-29').offset(1, 'year') - Instant((2013, 2, 28)) - - >>> instant(2014).offset(-1, 'day') - Instant((2013, 12, 31)) - >>> instant(2014).offset(-1, 'month') - Instant((2013, 12, 1)) - >>> instant(2014).offset(-1, 'year') - Instant((2013, 1, 1)) - - >>> instant('2011-3-1').offset(-1, 'day') - Instant((2011, 2, 28)) - >>> instant('2011-3-31').offset(-1, 'month') - Instant((2011, 2, 28)) - >>> instant('2012-2-29').offset(-1, 'year') - Instant((2011, 2, 28)) - - >>> instant('2014-1-30').offset(3, 'day') - Instant((2014, 2, 2)) - >>> instant('2014-10-2').offset(3, 'month') - Instant((2015, 1, 2)) - >>> instant('2014-1-1').offset(3, 'year') - Instant((2017, 1, 1)) - - >>> instant(2014).offset(-3, 'day') - Instant((2013, 12, 29)) - >>> instant(2014).offset(-3, 'month') - Instant((2013, 10, 1)) - >>> instant(2014).offset(-3, 'year') - Instant((2011, 1, 1)) - - >>> instant(2014).offset('first-of', 'month') - Instant((2014, 1, 1)) - >>> instant('2014-2').offset('first-of', 'month') - Instant((2014, 2, 1)) - >>> instant('2014-2-3').offset('first-of', 'month') - Instant((2014, 2, 1)) - - >>> instant(2014).offset('first-of', 'year') - Instant((2014, 1, 1)) - >>> instant('2014-2').offset('first-of', 'year') - Instant((2014, 1, 1)) - >>> instant('2014-2-3').offset('first-of', 'year') - Instant((2014, 1, 1)) - - >>> instant(2014).offset('last-of', 'month') - Instant((2014, 1, 31)) - >>> instant('2014-2').offset('last-of', 'month') - Instant((2014, 2, 28)) - >>> instant('2012-2-3').offset('last-of', 'month') - Instant((2012, 2, 29)) - - >>> instant(2014).offset('last-of', 'year') - Instant((2014, 12, 31)) - >>> instant('2014-2').offset('last-of', 'year') - Instant((2014, 12, 31)) - >>> instant('2014-2-3').offset('last-of', 'year') - Instant((2014, 12, 31)) - """ - year, month, day = self - assert unit in (config.DAY, config.MONTH, config.YEAR), 'Invalid unit: {} of type {}'.format(unit, type(unit)) - if offset == 'first-of': - if unit == config.MONTH: - day = 1 - elif unit == config.YEAR: - month = 1 - day = 1 - elif offset == 'last-of': - if unit == config.MONTH: - day = calendar.monthrange(year, month)[1] - elif unit == config.YEAR: - month = 12 - day = 31 - else: - assert isinstance(offset, int), 'Invalid offset: {} of type {}'.format(offset, type(offset)) - if unit == config.DAY: - day += offset - if offset < 0: - while day < 1: - month -= 1 - if month == 0: - year -= 1 - month = 12 - day += calendar.monthrange(year, month)[1] - elif offset > 0: - month_last_day = calendar.monthrange(year, month)[1] - while day > month_last_day: - month += 1 - if month == 13: - year += 1 - month = 1 - day -= month_last_day - month_last_day = calendar.monthrange(year, month)[1] - elif unit == config.MONTH: - month += offset - if offset < 0: - while month < 1: - year -= 1 - month += 12 - elif offset > 0: - while month > 12: - year += 1 - month -= 12 - month_last_day = calendar.monthrange(year, month)[1] - if day > month_last_day: - day = month_last_day - elif unit == config.YEAR: - year += offset - # Handle february month of leap year. - month_last_day = calendar.monthrange(year, month)[1] - if day > month_last_day: - day = month_last_day - - return self.__class__((year, month, day)) + @property + def year(self) -> int: + return self[0] @property - def year(self): - """ - Extract year from instant. - - >>> instant(2014).year - 2014 - >>> instant('2014-2').year - 2014 - >>> instant('2014-2-3').year - 2014 + def is_eternal(self) -> bool: + return self == self.eternity() + + def offset(self, offset: str | int, unit: t.DateUnit) -> t.Instant | None: + """Increments/decrements the given instant with offset units. + + Args: + offset: How much of ``unit`` to offset. + unit: What to offset + + Returns: + :obj:`.Instant`: A new :obj:`.Instant` in time. + + Raises: + :exc:`AssertionError`: When ``unit`` is not a date unit. + :exc:`AssertionError`: When ``offset`` is not either ``first-of``, + ``last-of``, or any :obj:`int`. + + Examples: + >>> Instant((2020, 12, 31)).offset("first-of", DateUnit.MONTH) + Instant((2020, 12, 1)) + + >>> Instant((2020, 1, 1)).offset("last-of", DateUnit.YEAR) + Instant((2020, 12, 31)) + + >>> Instant((2020, 1, 1)).offset(1, DateUnit.YEAR) + Instant((2021, 1, 1)) + + >>> Instant((2020, 1, 1)).offset(-3, DateUnit.DAY) + Instant((2019, 12, 29)) + """ - return self[0] + year, month, _ = self + + assert unit in ( + DateUnit.isoformat + DateUnit.isocalendar + ), f"Invalid unit: {unit} of type {type(unit)}" + + if offset == "first-of": + if unit == DateUnit.YEAR: + return self.__class__((year, 1, 1)) + + if unit == DateUnit.MONTH: + return self.__class__((year, month, 1)) + + if unit == DateUnit.WEEK: + date = self.date + date = date.start_of("week") + return self.__class__((date.year, date.month, date.day)) + return None + + if offset == "last-of": + if unit == DateUnit.YEAR: + return self.__class__((year, 12, 31)) + + if unit == DateUnit.MONTH: + date = self.date + date = date.end_of("month") + return self.__class__((date.year, date.month, date.day)) + + if unit == DateUnit.WEEK: + date = self.date + date = date.end_of("week") + return self.__class__((date.year, date.month, date.day)) + return None + + assert isinstance( + offset, + int, + ), f"Invalid offset: {offset} of type {type(offset)}" + + if unit == DateUnit.YEAR: + date = self.date + date = date.add(years=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit == DateUnit.MONTH: + date = self.date + date = date.add(months=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit == DateUnit.WEEK: + date = self.date + date = date.add(weeks=offset) + return self.__class__((date.year, date.month, date.day)) + + if unit in (DateUnit.DAY, DateUnit.WEEKDAY): + date = self.date + date = date.add(days=offset) + return self.__class__((date.year, date.month, date.day)) + return None + + @classmethod + def eternity(cls) -> t.Instant: + """Return an eternity instant.""" + return cls((-1, -1, -1)) + + +__all__ = ["Instant"] diff --git a/openfisca_core/periods/period_.py b/openfisca_core/periods/period_.py index 808540f28a..00e833d861 100644 --- a/openfisca_core/periods/period_.py +++ b/openfisca_core/periods/period_.py @@ -1,124 +1,415 @@ from __future__ import annotations +from collections.abc import Sequence + import calendar +import datetime -from openfisca_core import periods -from openfisca_core.periods import config, helpers +import pendulum +from . import helpers, types as t +from .date_unit import DateUnit +from .instant_ import Instant -class Period(tuple): - """ - Toolbox to handle date intervals. - A period is a triple (unit, start, size), where unit is either "month" or "year", where start format is a - (year, month, day) triple, and where size is an integer > 1. +class Period(tuple[t.DateUnit, t.Instant, int]): + """Toolbox to handle date intervals. + + A :class:`.Period` is a triple (``unit``, ``start``, ``size``). + + Attributes: + unit (:obj:`str`): + Either ``year``, ``month``, ``day`` or ``eternity``. + start (:obj:`.Instant`): + The "instant" the :obj:`.Period` starts at. + size (:obj:`int`): + The amount of ``unit``, starting at ``start``, at least ``1``. + + Args: + (tuple(tuple(str, .Instant, int))): + The ``unit``, ``start``, and ``size``, accordingly. + + Examples: + >>> instant = Instant((2021, 10, 1)) + >>> period = Period((DateUnit.YEAR, instant, 3)) + + >>> repr(Period) + "" + + >>> repr(period) + "Period((, Instant((2021, 10, 1)), 3))" + + >>> str(period) + 'year:2021-10:3' + + >>> dict([period, instant]) + Traceback (most recent call last): + ValueError: dictionary update sequence element #0 has length 3... + + >>> list(period) + [, Instant((2021, 10, 1)), 3] + + >>> period[0] + + + >>> period[0] in period + True + + >>> len(period) + 3 + + >>> period == Period((DateUnit.YEAR, instant, 3)) + True + + >>> period != Period((DateUnit.YEAR, instant, 3)) + False + + >>> period > Period((DateUnit.YEAR, instant, 3)) + False + + >>> period < Period((DateUnit.YEAR, instant, 3)) + False + + >>> period >= Period((DateUnit.YEAR, instant, 3)) + True + + >>> period <= Period((DateUnit.YEAR, instant, 3)) + True + + >>> period.days + 1096 + + >>> period.size_in_months + 36 + + >>> period.size_in_days + 1096 + + >>> period.stop + Instant((2024, 9, 30)) + + >>> period.unit + + + >>> period.last_3_months + Period((, Instant((2021, 7, 1)), 3)) + + >>> period.last_month + Period((, Instant((2021, 9, 1)), 1)) + + >>> period.last_year + Period((, Instant((2020, 1, 1)), 1)) + + >>> period.n_2 + Period((, Instant((2019, 1, 1)), 1)) + + >>> period.this_year + Period((, Instant((2021, 1, 1)), 1)) + + >>> period.first_month + Period((, Instant((2021, 10, 1)), 1)) + + >>> period.first_day + Period((, Instant((2021, 10, 1)), 1)) + Since a period is a triple it can be used as a dictionary key. + """ - def __repr__(self): - """ - Transform period to to its Python representation as a string. - - >>> repr(period('year', 2014)) - "Period(('year', Instant((2014, 1, 1)), 1))" - >>> repr(period('month', '2014-2')) - "Period(('month', Instant((2014, 2, 1)), 1))" - >>> repr(period('day', '2014-2-3')) - "Period(('day', Instant((2014, 2, 3)), 1))" - """ - return '{}({})'.format(self.__class__.__name__, super(Period, self).__repr__()) + __slots__ = () - def __str__(self): - """ - Transform period to a string. - - >>> str(period(YEAR, 2014)) - '2014' - - >>> str(period(YEAR, '2014-2')) - 'year:2014-02' - >>> str(period(MONTH, '2014-2')) - '2014-02' - - >>> str(period(YEAR, 2012, size = 2)) - 'year:2012:2' - >>> str(period(MONTH, 2012, size = 2)) - 'month:2012-01:2' - >>> str(period(MONTH, 2012, size = 12)) - '2012' - - >>> str(period(YEAR, '2012-3', size = 2)) - 'year:2012-03:2' - >>> str(period(MONTH, '2012-3', size = 2)) - 'month:2012-03:2' - >>> str(period(MONTH, '2012-3', size = 12)) - 'year:2012-03' - """ + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + def __str__(self) -> t.PeriodStr: unit, start_instant, size = self - if unit == config.ETERNITY: - return 'ETERNITY' - year, month, day = start_instant + + if unit == DateUnit.ETERNITY: + return t.PeriodStr(unit.upper()) + + # ISO format date units. + f_year, month, day = start_instant + + # ISO calendar date units. + c_year, week, weekday = datetime.date(f_year, month, day).isocalendar() # 1 year long period - if (unit == config.MONTH and size == 12 or unit == config.YEAR and size == 1): + if unit == DateUnit.MONTH and size == 12 or unit == DateUnit.YEAR and size == 1: if month == 1: # civil year starting from january - return str(year) - else: - # rolling year - return '{}:{}-{:02d}'.format(config.YEAR, year, month) + return t.PeriodStr(str(f_year)) + # rolling year + return t.PeriodStr(f"{DateUnit.YEAR}:{f_year}-{month:02d}") + # simple month - if unit == config.MONTH and size == 1: - return '{}-{:02d}'.format(year, month) + if unit == DateUnit.MONTH and size == 1: + return t.PeriodStr(f"{f_year}-{month:02d}") + # several civil years - if unit == config.YEAR and month == 1: - return '{}:{}:{}'.format(unit, year, size) + if unit == DateUnit.YEAR and month == 1: + return t.PeriodStr(f"{unit}:{f_year}:{size}") - if unit == config.DAY: + if unit == DateUnit.DAY: if size == 1: - return '{}-{:02d}-{:02d}'.format(year, month, day) - else: - return '{}:{}-{:02d}-{:02d}:{}'.format(unit, year, month, day, size) + return t.PeriodStr(f"{f_year}-{month:02d}-{day:02d}") + return t.PeriodStr(f"{unit}:{f_year}-{month:02d}-{day:02d}:{size}") + + # 1 week + if unit == DateUnit.WEEK and size == 1: + if week < 10: + return t.PeriodStr(f"{c_year}-W0{week}") + + return t.PeriodStr(f"{c_year}-W{week}") + + # several weeks + if unit == DateUnit.WEEK and size > 1: + if week < 10: + return t.PeriodStr(f"{unit}:{c_year}-W0{week}:{size}") + + return t.PeriodStr(f"{unit}:{c_year}-W{week}:{size}") + + # 1 weekday + if unit == DateUnit.WEEKDAY and size == 1: + if week < 10: + return t.PeriodStr(f"{c_year}-W0{week}-{weekday}") + + return t.PeriodStr(f"{c_year}-W{week}-{weekday}") + + # several weekdays + if unit == DateUnit.WEEKDAY and size > 1: + if week < 10: + return t.PeriodStr(f"{unit}:{c_year}-W0{week}-{weekday}:{size}") + + return t.PeriodStr(f"{unit}:{c_year}-W{week}-{weekday}:{size}") # complex period - return '{}:{}-{:02d}:{}'.format(unit, year, month, size) + return t.PeriodStr(f"{unit}:{f_year}-{month:02d}:{size}") @property - def date(self): - assert self.size == 1, '"date" is undefined for a period of size > 1: {}'.format(self) + def unit(self) -> t.DateUnit: + """The ``unit`` of the ``Period``. + + Example: + >>> instant = Instant((2021, 10, 1)) + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.unit + + + """ + return self[0] + + @property + def start(self) -> t.Instant: + """The ``Instant`` at which the ``Period`` starts. + + Example: + >>> instant = Instant((2021, 10, 1)) + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.start + Instant((2021, 10, 1)) + + """ + return self[1] + + @property + def size(self) -> int: + """The ``size`` of the ``Period``. + + Example: + >>> instant = Instant((2021, 10, 1)) + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size + 3 + + """ + return self[2] + + @property + def date(self) -> pendulum.Date: + """The date representation of the ``Period`` start date. + + Examples: + >>> instant = Instant((2021, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 1)) + >>> period.date + Date(2021, 10, 1) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.date + Traceback (most recent call last): + ValueError: "date" is undefined for a period of size > 1: year:2021-10:3. + + """ + if self.size != 1: + msg = f'"date" is undefined for a period of size > 1: {self}.' + raise ValueError(msg) + return self.start.date @property - def days(self): + def size_in_years(self) -> int: + """The ``size`` of the ``Period`` in years. + + Examples: + >>> instant = Instant((2021, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_years + 3 + + >>> period = Period((DateUnit.MONTH, instant, 3)) + >>> period.size_in_years + Traceback (most recent call last): + ValueError: Can't calculate number of years in a month. + """ - Count the number of days in period. - - >>> period('day', 2014).days - 365 - >>> period('month', 2014).days - 365 - >>> period('year', 2014).days - 365 - - >>> period('day', '2014-2').days - 28 - >>> period('month', '2014-2').days - 28 - >>> period('year', '2014-2').days - 365 - - >>> period('day', '2014-2-3').days - 1 - >>> period('month', '2014-2-3').days - 28 - >>> period('year', '2014-2-3').days - 365 + if self.unit == DateUnit.YEAR: + return self.size + + msg = f"Can't calculate number of years in a {self.unit}." + raise ValueError(msg) + + @property + def size_in_months(self) -> int: + """The ``size`` of the ``Period`` in months. + + Examples: + >>> instant = Instant((2021, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_months + 36 + + >>> period = Period((DateUnit.DAY, instant, 3)) + >>> period.size_in_months + Traceback (most recent call last): + ValueError: Can't calculate number of months in a day. + + """ + if self.unit == DateUnit.YEAR: + return self.size * 12 + + if self.unit == DateUnit.MONTH: + return self.size + + msg = f"Can't calculate number of months in a {self.unit}." + raise ValueError(msg) + + @property + def size_in_days(self) -> int: + """The ``size`` of the ``Period`` in days. + + Examples: + >>> instant = Instant((2019, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_days + 1096 + + >>> period = Period((DateUnit.MONTH, instant, 3)) + >>> period.size_in_days + 92 + """ + if self.unit in (DateUnit.YEAR, DateUnit.MONTH): + last = self.start.offset(self.size, self.unit) + if last is None: + raise NotImplementedError + last_day = last.offset(-1, DateUnit.DAY) + if last_day is None: + raise NotImplementedError + return (last_day.date - self.start.date).days + 1 + + if self.unit == DateUnit.WEEK: + return self.size * 7 + + if self.unit in (DateUnit.DAY, DateUnit.WEEKDAY): + return self.size + + msg = f"Can't calculate number of days in a {self.unit}." + raise ValueError(msg) + + @property + def size_in_weeks(self) -> int: + """The ``size`` of the ``Period`` in weeks. + + Examples: + >>> instant = Instant((2019, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_weeks + 156 + + >>> period = Period((DateUnit.YEAR, instant, 5)) + >>> period.size_in_weeks + 261 + + """ + if self.unit == DateUnit.YEAR: + start = self.start.date + cease = start.add(years=self.size) + delta = start.diff(cease) + return delta.in_weeks() + + if self.unit == DateUnit.MONTH: + start = self.start.date + cease = start.add(months=self.size) + delta = start.diff(cease) + return delta.in_weeks() + + if self.unit == DateUnit.WEEK: + return self.size + + msg = f"Can't calculate number of weeks in a {self.unit}." + raise ValueError(msg) + + @property + def size_in_weekdays(self) -> int: + """The ``size`` of the ``Period`` in weekdays. + + Examples: + >>> instant = Instant((2019, 10, 1)) + + >>> period = Period((DateUnit.YEAR, instant, 3)) + >>> period.size_in_weekdays + 1092 + + >>> period = Period((DateUnit.WEEK, instant, 3)) + >>> period.size_in_weekdays + 21 + + """ + if self.unit == DateUnit.YEAR: + return self.size_in_weeks * 7 + + if DateUnit.MONTH in self.unit: + last = self.start.offset(self.size, self.unit) + if last is None: + raise NotImplementedError + last_day = last.offset(-1, DateUnit.DAY) + if last_day is None: + raise NotImplementedError + return (last_day.date - self.start.date).days + 1 + + if self.unit == DateUnit.WEEK: + return self.size * 7 + + if self.unit in (DateUnit.DAY, DateUnit.WEEKDAY): + return self.size + + msg = f"Can't calculate number of weekdays in a {self.unit}." + raise ValueError(msg) + + @property + def days(self) -> int: + """Same as ``size_in_days``.""" return (self.stop.date - self.start.date).days + 1 - def intersection(self, start, stop): + def intersection( + self, start: t.Instant | None, stop: t.Instant | None + ) -> t.Period | None: if start is None and stop is None: return self period_start = self[1] @@ -133,351 +424,495 @@ def intersection(self, start, stop): intersection_stop = min(period_stop, stop) if intersection_start == period_start and intersection_stop == period_stop: return self - if intersection_start.day == 1 and intersection_start.month == 1 \ - and intersection_stop.day == 31 and intersection_stop.month == 12: - return self.__class__(( - 'year', - intersection_start, - intersection_stop.year - intersection_start.year + 1, - )) - if intersection_start.day == 1 and intersection_stop.day == calendar.monthrange(intersection_stop.year, - intersection_stop.month)[1]: - return self.__class__(( - 'month', - intersection_start, + if ( + intersection_start.day == 1 + and intersection_start.month == 1 + and intersection_stop.day == 31 + and intersection_stop.month == 12 + ): + return self.__class__( + ( + DateUnit.YEAR, + intersection_start, + intersection_stop.year - intersection_start.year + 1, + ), + ) + if ( + intersection_start.day == 1 + and intersection_stop.day + == calendar.monthrange(intersection_stop.year, intersection_stop.month)[1] + ): + return self.__class__( ( - (intersection_stop.year - intersection_start.year) * 12 - + intersection_stop.month - - intersection_start.month - + 1 + DateUnit.MONTH, + intersection_start, + ( + (intersection_stop.year - intersection_start.year) * 12 + + intersection_stop.month + - intersection_start.month + + 1 ), - )) - return self.__class__(( - 'day', - intersection_start, - (intersection_stop.date - intersection_start.date).days + 1, - )) - - def get_subperiods(self, unit): - """ - Return the list of all the periods of unit ``unit`` contained in self. + ), + ) + return self.__class__( + ( + DateUnit.DAY, + intersection_start, + (intersection_stop.date - intersection_start.date).days + 1, + ), + ) + + def get_subperiods(self, unit: t.DateUnit) -> Sequence[t.Period]: + """Return the list of periods of unit ``unit`` contained in self. Examples: + >>> period = Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)) + >>> period.get_subperiods(DateUnit.MONTH) + [Period((, Instant((2021, 1, 1)), 1)),...] - >>> period('2017').get_subperiods(MONTH) - >>> [period('2017-01'), period('2017-02'), ... period('2017-12')] + >>> period = Period((DateUnit.YEAR, Instant((2021, 1, 1)), 2)) + >>> period.get_subperiods(DateUnit.YEAR) + [Period((, Instant((2021, 1, 1)), 1)), P...] - >>> period('year:2014:2').get_subperiods(YEAR) - >>> [period('2014'), period('2015')] """ if helpers.unit_weight(self.unit) < helpers.unit_weight(unit): - raise ValueError('Cannot subdivide {0} into {1}'.format(self.unit, unit)) + msg = f"Cannot subdivide {self.unit} into {unit}" + raise ValueError(msg) - if unit == config.YEAR: - return [self.this_year.offset(i, config.YEAR) for i in range(self.size)] + if unit == DateUnit.YEAR: + return [self.this_year.offset(i, DateUnit.YEAR) for i in range(self.size)] - if unit == config.MONTH: - return [self.first_month.offset(i, config.MONTH) for i in range(self.size_in_months)] + if unit == DateUnit.MONTH: + return [ + self.first_month.offset(i, DateUnit.MONTH) + for i in range(self.size_in_months) + ] - if unit == config.DAY: - return [self.first_day.offset(i, config.DAY) for i in range(self.size_in_days)] + if unit == DateUnit.DAY: + return [ + self.first_day.offset(i, DateUnit.DAY) for i in range(self.size_in_days) + ] - def offset(self, offset, unit = None): - """ - Increment (or decrement) the given period with offset units. - - >>> period('day', 2014).offset(1) - Period(('day', Instant((2014, 1, 2)), 365)) - >>> period('day', 2014).offset(1, 'day') - Period(('day', Instant((2014, 1, 2)), 365)) - >>> period('day', 2014).offset(1, 'month') - Period(('day', Instant((2014, 2, 1)), 365)) - >>> period('day', 2014).offset(1, 'year') - Period(('day', Instant((2015, 1, 1)), 365)) - - >>> period('month', 2014).offset(1) - Period(('month', Instant((2014, 2, 1)), 12)) - >>> period('month', 2014).offset(1, 'day') - Period(('month', Instant((2014, 1, 2)), 12)) - >>> period('month', 2014).offset(1, 'month') - Period(('month', Instant((2014, 2, 1)), 12)) - >>> period('month', 2014).offset(1, 'year') - Period(('month', Instant((2015, 1, 1)), 12)) - - >>> period('year', 2014).offset(1) - Period(('year', Instant((2015, 1, 1)), 1)) - >>> period('year', 2014).offset(1, 'day') - Period(('year', Instant((2014, 1, 2)), 1)) - >>> period('year', 2014).offset(1, 'month') - Period(('year', Instant((2014, 2, 1)), 1)) - >>> period('year', 2014).offset(1, 'year') - Period(('year', Instant((2015, 1, 1)), 1)) - - >>> period('day', '2011-2-28').offset(1) - Period(('day', Instant((2011, 3, 1)), 1)) - >>> period('month', '2011-2-28').offset(1) - Period(('month', Instant((2011, 3, 28)), 1)) - >>> period('year', '2011-2-28').offset(1) - Period(('year', Instant((2012, 2, 28)), 1)) - - >>> period('day', '2011-3-1').offset(-1) - Period(('day', Instant((2011, 2, 28)), 1)) - >>> period('month', '2011-3-1').offset(-1) - Period(('month', Instant((2011, 2, 1)), 1)) - >>> period('year', '2011-3-1').offset(-1) - Period(('year', Instant((2010, 3, 1)), 1)) - - >>> period('day', '2014-1-30').offset(3) - Period(('day', Instant((2014, 2, 2)), 1)) - >>> period('month', '2014-1-30').offset(3) - Period(('month', Instant((2014, 4, 30)), 1)) - >>> period('year', '2014-1-30').offset(3) - Period(('year', Instant((2017, 1, 30)), 1)) - - >>> period('day', 2014).offset(-3) - Period(('day', Instant((2013, 12, 29)), 365)) - >>> period('month', 2014).offset(-3) - Period(('month', Instant((2013, 10, 1)), 12)) - >>> period('year', 2014).offset(-3) - Period(('year', Instant((2011, 1, 1)), 1)) - - >>> period('day', '2014-2-3').offset('first-of', 'month') - Period(('day', Instant((2014, 2, 1)), 1)) - >>> period('day', '2014-2-3').offset('first-of', 'year') - Period(('day', Instant((2014, 1, 1)), 1)) - - >>> period('day', '2014-2-3', 4).offset('first-of', 'month') - Period(('day', Instant((2014, 2, 1)), 4)) - >>> period('day', '2014-2-3', 4).offset('first-of', 'year') - Period(('day', Instant((2014, 1, 1)), 4)) - - >>> period('month', '2014-2-3').offset('first-of') - Period(('month', Instant((2014, 2, 1)), 1)) - >>> period('month', '2014-2-3').offset('first-of', 'month') - Period(('month', Instant((2014, 2, 1)), 1)) - >>> period('month', '2014-2-3').offset('first-of', 'year') - Period(('month', Instant((2014, 1, 1)), 1)) - - >>> period('month', '2014-2-3', 4).offset('first-of') - Period(('month', Instant((2014, 2, 1)), 4)) - >>> period('month', '2014-2-3', 4).offset('first-of', 'month') - Period(('month', Instant((2014, 2, 1)), 4)) - >>> period('month', '2014-2-3', 4).offset('first-of', 'year') - Period(('month', Instant((2014, 1, 1)), 4)) - - >>> period('year', 2014).offset('first-of') - Period(('year', Instant((2014, 1, 1)), 1)) - >>> period('year', 2014).offset('first-of', 'month') - Period(('year', Instant((2014, 1, 1)), 1)) - >>> period('year', 2014).offset('first-of', 'year') - Period(('year', Instant((2014, 1, 1)), 1)) - - >>> period('year', '2014-2-3').offset('first-of') - Period(('year', Instant((2014, 1, 1)), 1)) - >>> period('year', '2014-2-3').offset('first-of', 'month') - Period(('year', Instant((2014, 2, 1)), 1)) - >>> period('year', '2014-2-3').offset('first-of', 'year') - Period(('year', Instant((2014, 1, 1)), 1)) - - >>> period('day', '2014-2-3').offset('last-of', 'month') - Period(('day', Instant((2014, 2, 28)), 1)) - >>> period('day', '2014-2-3').offset('last-of', 'year') - Period(('day', Instant((2014, 12, 31)), 1)) - - >>> period('day', '2014-2-3', 4).offset('last-of', 'month') - Period(('day', Instant((2014, 2, 28)), 4)) - >>> period('day', '2014-2-3', 4).offset('last-of', 'year') - Period(('day', Instant((2014, 12, 31)), 4)) - - >>> period('month', '2014-2-3').offset('last-of') - Period(('month', Instant((2014, 2, 28)), 1)) - >>> period('month', '2014-2-3').offset('last-of', 'month') - Period(('month', Instant((2014, 2, 28)), 1)) - >>> period('month', '2014-2-3').offset('last-of', 'year') - Period(('month', Instant((2014, 12, 31)), 1)) - - >>> period('month', '2014-2-3', 4).offset('last-of') - Period(('month', Instant((2014, 2, 28)), 4)) - >>> period('month', '2014-2-3', 4).offset('last-of', 'month') - Period(('month', Instant((2014, 2, 28)), 4)) - >>> period('month', '2014-2-3', 4).offset('last-of', 'year') - Period(('month', Instant((2014, 12, 31)), 4)) - - >>> period('year', 2014).offset('last-of') - Period(('year', Instant((2014, 12, 31)), 1)) - >>> period('year', 2014).offset('last-of', 'month') - Period(('year', Instant((2014, 1, 31)), 1)) - >>> period('year', 2014).offset('last-of', 'year') - Period(('year', Instant((2014, 12, 31)), 1)) - - >>> period('year', '2014-2-3').offset('last-of') - Period(('year', Instant((2014, 12, 31)), 1)) - >>> period('year', '2014-2-3').offset('last-of', 'month') - Period(('year', Instant((2014, 2, 28)), 1)) - >>> period('year', '2014-2-3').offset('last-of', 'year') - Period(('year', Instant((2014, 12, 31)), 1)) - """ - return self.__class__((self[0], self[1].offset(offset, self[0] if unit is None else unit), self[2])) + if unit == DateUnit.WEEK: + return [ + self.first_week.offset(i, DateUnit.WEEK) + for i in range(self.size_in_weeks) + ] + + if unit == DateUnit.WEEKDAY: + return [ + self.first_weekday.offset(i, DateUnit.WEEKDAY) + for i in range(self.size_in_weekdays) + ] + + msg = f"Cannot subdivide {self.unit} into {unit}" + raise ValueError(msg) + + def offset(self, offset: str | int, unit: t.DateUnit | None = None) -> t.Period: + """Increment (or decrement) the given period with offset units. + + Examples: + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset(1) + Period((, Instant((2021, 1, 2)), 365)) + + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset( + ... 1, DateUnit.DAY + ... ) + Period((, Instant((2021, 1, 2)), 365)) + + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset( + ... 1, DateUnit.MONTH + ... ) + Period((, Instant((2021, 2, 1)), 365)) + + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset( + ... 1, DateUnit.YEAR + ... ) + Period((, Instant((2022, 1, 1)), 365)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset(1) + Period((, Instant((2021, 2, 1)), 12)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset( + ... 1, DateUnit.DAY + ... ) + Period((, Instant((2021, 1, 2)), 12)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset( + ... 1, DateUnit.MONTH + ... ) + Period((, Instant((2021, 2, 1)), 12)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset( + ... 1, DateUnit.YEAR + ... ) + Period((, Instant((2022, 1, 1)), 12)) + + >>> Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)).offset(1) + Period((, Instant((2022, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)).offset( + ... 1, DateUnit.DAY + ... ) + Period((, Instant((2021, 1, 2)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)).offset( + ... 1, DateUnit.MONTH + ... ) + Period((, Instant((2021, 2, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2021, 1, 1)), 1)).offset( + ... 1, DateUnit.YEAR + ... ) + Period((, Instant((2022, 1, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2011, 2, 28)), 1)).offset(1) + Period((, Instant((2011, 3, 1)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2011, 2, 28)), 1)).offset(1) + Period((, Instant((2011, 3, 28)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2011, 2, 28)), 1)).offset(1) + Period((, Instant((2012, 2, 28)), 1)) + + >>> Period((DateUnit.DAY, Instant((2011, 3, 1)), 1)).offset(-1) + Period((, Instant((2011, 2, 28)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2011, 3, 1)), 1)).offset(-1) + Period((, Instant((2011, 2, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2011, 3, 1)), 1)).offset(-1) + Period((, Instant((2010, 3, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 1, 30)), 1)).offset(3) + Period((, Instant((2014, 2, 2)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 1, 30)), 1)).offset(3) + Period((, Instant((2014, 4, 30)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset(3) + Period((, Instant((2017, 1, 30)), 1)) + + >>> Period((DateUnit.DAY, Instant((2021, 1, 1)), 365)).offset(-3) + Period((, Instant((2020, 12, 29)), 365)) + + >>> Period((DateUnit.MONTH, Instant((2021, 1, 1)), 12)).offset(-3) + Period((, Instant((2020, 10, 1)), 12)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 1)), 1)).offset(-3) + Period((, Instant((2011, 1, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 4)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("first-of") + Period((, Instant((2014, 2, 1)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("first-of") + Period((, Instant((2014, 2, 1)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 4)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset("first-of") + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 30)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("first-of") + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 1)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "first-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 1, 1)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 4)) + + >>> Period((DateUnit.DAY, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset("last-of") + Period((, Instant((2014, 2, 28)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset("last-of") + Period((, Instant((2014, 2, 28)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 4)) + + >>> Period((DateUnit.MONTH, Instant((2014, 2, 3)), 4)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 4)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of") + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 1, 1)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 1, 31)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset("last-of") + Period((, Instant((2014, 12, 31)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.MONTH + ... ) + Period((, Instant((2014, 2, 28)), 1)) + + >>> Period((DateUnit.YEAR, Instant((2014, 2, 3)), 1)).offset( + ... "last-of", DateUnit.YEAR + ... ) + Period((, Instant((2014, 12, 31)), 1)) - def contains(self, other: Period) -> bool: """ - Returns ``True`` if the period contains ``other``. For instance, ``period(2015)`` contains ``period(2015-01)`` + + start: None | t.Instant = self[1].offset( + offset, self[0] if unit is None else unit + ) + + if start is None: + raise NotImplementedError + + return self.__class__( + ( + self[0], + start, + self[2], + ), + ) + + def contains(self, other: t.Period) -> bool: + """Returns ``True`` if the period contains ``other``. + + For instance, ``period(2015)`` contains ``period(2015-01)``. + """ return self.start <= other.start and self.stop >= other.stop @property - def size(self): - """ - Return the size of the period. + def stop(self) -> t.Instant: + """Return the last day of the period as an Instant instance. - >>> period('month', '2012-2-29', 4).size - 4 - """ - return self[2] + Examples: + >>> Period((DateUnit.YEAR, Instant((2022, 1, 1)), 1)).stop + Instant((2022, 12, 31)) - @property - def size_in_months(self): - """ - Return the size of the period in months. + >>> Period((DateUnit.MONTH, Instant((2022, 1, 1)), 12)).stop + Instant((2022, 12, 31)) - >>> period('month', '2012-2-29', 4).size_in_months - 4 - >>> period('year', '2012', 1).size_in_months - 12 - """ - if (self[0] == config.MONTH): - return self[2] - if(self[0] == config.YEAR): - return self[2] * 12 - raise ValueError("Cannot calculate number of months in {0}".format(self[0])) + >>> Period((DateUnit.DAY, Instant((2022, 1, 1)), 365)).stop + Instant((2022, 12, 31)) - @property - def size_in_days(self): - """ - Return the size of the period in days. + >>> Period((DateUnit.YEAR, Instant((2012, 2, 29)), 1)).stop + Instant((2013, 2, 27)) - >>> period('month', '2012-2-29', 4).size_in_days - 28 - >>> period('year', '2012', 1).size_in_days - 366 - """ - unit, instant, length = self + >>> Period((DateUnit.MONTH, Instant((2012, 2, 29)), 1)).stop + Instant((2012, 3, 28)) - if unit == config.DAY: - return length - if unit in [config.MONTH, config.YEAR]: - last_day = self.start.offset(length, unit).offset(-1, config.DAY) - return (last_day.date - self.start.date).days + 1 + >>> Period((DateUnit.DAY, Instant((2012, 2, 29)), 1)).stop + Instant((2012, 2, 29)) - raise ValueError("Cannot calculate number of days in {0}".format(unit)) + >>> Period((DateUnit.YEAR, Instant((2012, 2, 29)), 2)).stop + Instant((2014, 2, 27)) - @property - def start(self) -> periods.Instant: - """ - Return the first day of the period as an Instant instance. + >>> Period((DateUnit.MONTH, Instant((2012, 2, 29)), 2)).stop + Instant((2012, 4, 28)) - >>> period('month', '2012-2-29', 4).start - Instant((2012, 2, 29)) - """ - return self[1] + >>> Period((DateUnit.DAY, Instant((2012, 2, 29)), 2)).stop + Instant((2012, 3, 1)) - @property - def stop(self) -> periods.Instant: - """ - Return the last day of the period as an Instant instance. - - >>> period('year', 2014).stop - Instant((2014, 12, 31)) - >>> period('month', 2014).stop - Instant((2014, 12, 31)) - >>> period('day', 2014).stop - Instant((2014, 12, 31)) - - >>> period('year', '2012-2-29').stop - Instant((2013, 2, 28)) - >>> period('month', '2012-2-29').stop - Instant((2012, 3, 28)) - >>> period('day', '2012-2-29').stop - Instant((2012, 2, 29)) - - >>> period('year', '2012-2-29', 2).stop - Instant((2014, 2, 28)) - >>> period('month', '2012-2-29', 2).stop - Instant((2012, 4, 28)) - >>> period('day', '2012-2-29', 2).stop - Instant((2012, 3, 1)) """ unit, start_instant, size = self - year, month, day = start_instant - if unit == config.ETERNITY: - return periods.Instant((float("inf"), float("inf"), float("inf"))) - if unit == 'day': - if size > 1: - day += size - 1 - month_last_day = calendar.monthrange(year, month)[1] - while day > month_last_day: - month += 1 - if month == 13: - year += 1 - month = 1 - day -= month_last_day - month_last_day = calendar.monthrange(year, month)[1] - else: - if unit == 'month': - month += size - while month > 12: - year += 1 - month -= 12 - else: - assert unit == 'year', 'Invalid unit: {} of type {}'.format(unit, type(unit)) - year += size - day -= 1 - if day < 1: - month -= 1 - if month == 0: - year -= 1 - month = 12 - day += calendar.monthrange(year, month)[1] - else: - month_last_day = calendar.monthrange(year, month)[1] - if day > month_last_day: - month += 1 - if month == 13: - year += 1 - month = 1 - day -= month_last_day - return periods.Instant((year, month, day)) + + if unit == DateUnit.ETERNITY: + return Instant.eternity() + + if unit == DateUnit.YEAR: + date = start_instant.date.add(years=size, days=-1) + return Instant((date.year, date.month, date.day)) + + if unit == DateUnit.MONTH: + date = start_instant.date.add(months=size, days=-1) + return Instant((date.year, date.month, date.day)) + + if unit == DateUnit.WEEK: + date = start_instant.date.add(weeks=size, days=-1) + return Instant((date.year, date.month, date.day)) + + if unit in (DateUnit.DAY, DateUnit.WEEKDAY): + date = start_instant.date.add(days=size - 1) + return Instant((date.year, date.month, date.day)) + + raise ValueError @property - def unit(self): - return self[0] + def is_eternal(self) -> bool: + return self == self.eternity() # Reference periods @property - def last_3_months(self): - return self.first_month.start.period('month', 3).offset(-3) + def last_week(self) -> t.Period: + return self.first_week.offset(-1) @property - def last_month(self): + def last_fortnight(self) -> t.Period: + start: t.Instant = self.first_week.start + return self.__class__((DateUnit.WEEK, start, 1)).offset(-2) + + @property + def last_2_weeks(self) -> t.Period: + start: t.Instant = self.first_week.start + return self.__class__((DateUnit.WEEK, start, 2)).offset(-2) + + @property + def last_26_weeks(self) -> t.Period: + start: t.Instant = self.first_week.start + return self.__class__((DateUnit.WEEK, start, 26)).offset(-26) + + @property + def last_52_weeks(self) -> t.Period: + start: t.Instant = self.first_week.start + return self.__class__((DateUnit.WEEK, start, 52)).offset(-52) + + @property + def last_month(self) -> t.Period: return self.first_month.offset(-1) @property - def last_year(self): - return self.start.offset('first-of', 'year').period('year').offset(-1) + def last_3_months(self) -> t.Period: + start: t.Instant = self.first_month.start + return self.__class__((DateUnit.MONTH, start, 3)).offset(-3) + + @property + def last_year(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.YEAR) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.YEAR, start, 1)).offset(-1) @property - def n_2(self): - return self.start.offset('first-of', 'year').period('year').offset(-2) + def n_2(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.YEAR) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.YEAR, start, 1)).offset(-2) @property - def this_year(self): - return self.start.offset('first-of', 'year').period('year') + def this_year(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.YEAR) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.YEAR, start, 1)) + + @property + def first_month(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.MONTH) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.MONTH, start, 1)) @property - def first_month(self): - return self.start.offset('first-of', 'month').period('month') + def first_day(self) -> t.Period: + return self.__class__((DateUnit.DAY, self.start, 1)) @property - def first_day(self): - return self.start.period('day') + def first_week(self) -> t.Period: + start: None | t.Instant = self.start.offset("first-of", DateUnit.WEEK) + if start is None: + raise NotImplementedError + return self.__class__((DateUnit.WEEK, start, 1)) + + @property + def first_weekday(self) -> t.Period: + return self.__class__((DateUnit.WEEKDAY, self.start, 1)) + + @classmethod + def eternity(cls) -> t.Period: + """Return an eternity period.""" + return cls((DateUnit.ETERNITY, Instant.eternity(), -1)) + + +__all__ = ["Period"] diff --git a/openfisca_core/periods/py.typed b/openfisca_core/periods/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/periods/tests/__init__.py b/openfisca_core/periods/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/periods/tests/helpers/__init__.py b/openfisca_core/periods/tests/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/periods/tests/helpers/test_helpers.py b/openfisca_core/periods/tests/helpers/test_helpers.py new file mode 100644 index 0000000000..175ea8c873 --- /dev/null +++ b/openfisca_core/periods/tests/helpers/test_helpers.py @@ -0,0 +1,65 @@ +import datetime + +import pytest + +from openfisca_core import periods +from openfisca_core.periods import DateUnit, Instant, Period + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + (None, None), + (Instant((1, 1, 1)), datetime.date(1, 1, 1)), + (Instant((4, 2, 29)), datetime.date(4, 2, 29)), + ((1, 1, 1), datetime.date(1, 1, 1)), + ], +) +def test_instant_date(arg, expected) -> None: + assert periods.instant_date(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (Instant((-1, 1, 1)), ValueError), + (Instant((1, -1, 1)), ValueError), + (Instant((1, 13, -1)), ValueError), + (Instant((1, 1, -1)), ValueError), + (Instant((1, 1, 32)), ValueError), + (Instant((1, 2, 29)), ValueError), + (Instant(("1", 1, 1)), TypeError), + ((1,), TypeError), + ((1, 1), TypeError), + ], +) +def test_instant_date_with_an_invalid_argument(arg, error) -> None: + with pytest.raises(error): + periods.instant_date(arg) + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + (Period((DateUnit.WEEKDAY, Instant((1, 1, 1)), 5)), "100_5"), + (Period((DateUnit.WEEK, Instant((1, 1, 1)), 26)), "200_26"), + (Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), "100_365"), + (Period((DateUnit.MONTH, Instant((1, 1, 1)), 12)), "200_12"), + (Period((DateUnit.YEAR, Instant((1, 1, 1)), 2)), "300_2"), + (Period((DateUnit.ETERNITY, Instant((1, 1, 1)), 1)), "400_1"), + ], +) +def test_key_period_size(arg, expected) -> None: + assert periods.key_period_size(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + ((DateUnit.DAY, None, 1), AttributeError), + ((DateUnit.MONTH, None, -1000), AttributeError), + ], +) +def test_key_period_size_when_an_invalid_argument(arg, error): + with pytest.raises(error): + periods.key_period_size(arg) diff --git a/openfisca_core/periods/tests/helpers/test_instant.py b/openfisca_core/periods/tests/helpers/test_instant.py new file mode 100644 index 0000000000..fb4472814b --- /dev/null +++ b/openfisca_core/periods/tests/helpers/test_instant.py @@ -0,0 +1,73 @@ +import datetime + +import pytest + +from openfisca_core import periods +from openfisca_core.periods import DateUnit, Instant, InstantError, Period + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + (datetime.date(1, 1, 1), Instant((1, 1, 1))), + (Instant((1, 1, 1)), Instant((1, 1, 1))), + (Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), Instant((1, 1, 1))), + (-1, Instant((-1, 1, 1))), + (0, Instant((0, 1, 1))), + (1, Instant((1, 1, 1))), + (999, Instant((999, 1, 1))), + (1000, Instant((1000, 1, 1))), + ("1000", Instant((1000, 1, 1))), + ("1000-01", Instant((1000, 1, 1))), + ("1000-01-01", Instant((1000, 1, 1))), + ((-1,), Instant((-1, 1, 1))), + ((-1, -1), Instant((-1, -1, 1))), + ((-1, -1, -1), Instant((-1, -1, -1))), + ], +) +def test_instant(arg, expected) -> None: + assert periods.instant(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (None, InstantError), + (DateUnit.YEAR, ValueError), + (DateUnit.ETERNITY, ValueError), + ("1000-0", ValueError), + ("1000-0-0", ValueError), + ("1000-1", ValueError), + ("1000-1-1", ValueError), + ("1", ValueError), + ("a", ValueError), + ("year", ValueError), + ("eternity", ValueError), + ("999", ValueError), + ("1:1000-01-01", ValueError), + ("a:1000-01-01", ValueError), + ("year:1000-01-01", ValueError), + ("year:1000-01-01:1", ValueError), + ("year:1000-01-01:3", ValueError), + ("1000-01-01:a", ValueError), + ("1000-01-01:1", ValueError), + ((), InstantError), + ({}, InstantError), + ("", InstantError), + ((None,), InstantError), + ((None, None), InstantError), + ((None, None, None), InstantError), + ((None, None, None, None), InstantError), + (("-1",), InstantError), + (("-1", "-1"), InstantError), + (("-1", "-1", "-1"), InstantError), + (("1-1",), InstantError), + (("1-1-1",), InstantError), + ((datetime.date(1, 1, 1),), InstantError), + ((Instant((1, 1, 1)),), InstantError), + ((Period((DateUnit.DAY, Instant((1, 1, 1)), 365)),), InstantError), + ], +) +def test_instant_with_an_invalid_argument(arg, error) -> None: + with pytest.raises(error): + periods.instant(arg) diff --git a/openfisca_core/periods/tests/helpers/test_period.py b/openfisca_core/periods/tests/helpers/test_period.py new file mode 100644 index 0000000000..d2d5c6679a --- /dev/null +++ b/openfisca_core/periods/tests/helpers/test_period.py @@ -0,0 +1,134 @@ +import datetime + +import pytest + +from openfisca_core import periods +from openfisca_core.periods import DateUnit, Instant, Period, PeriodError + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + ("eternity", Period((DateUnit.ETERNITY, Instant((-1, -1, -1)), -1))), + ("ETERNITY", Period((DateUnit.ETERNITY, Instant((-1, -1, -1)), -1))), + ( + DateUnit.ETERNITY, + Period((DateUnit.ETERNITY, Instant((-1, -1, -1)), -1)), + ), + (datetime.date(1, 1, 1), Period((DateUnit.DAY, Instant((1, 1, 1)), 1))), + (Instant((1, 1, 1)), Period((DateUnit.DAY, Instant((1, 1, 1)), 1))), + ( + Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), + Period((DateUnit.DAY, Instant((1, 1, 1)), 365)), + ), + (-1, Period((DateUnit.YEAR, Instant((-1, 1, 1)), 1))), + (0, Period((DateUnit.YEAR, Instant((0, 1, 1)), 1))), + (1, Period((DateUnit.YEAR, Instant((1, 1, 1)), 1))), + (999, Period((DateUnit.YEAR, Instant((999, 1, 1)), 1))), + (1000, Period((DateUnit.YEAR, Instant((1000, 1, 1)), 1))), + ("1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("1004-02-29", Period((DateUnit.DAY, Instant((1004, 2, 29)), 1))), + ("1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), + ("year:1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01-01", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-W01", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001-W01-1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-01-01:1", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("year:1001-W01:1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001-W01-1:1", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 1))), + ("year:1001:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-01:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-01-01:3", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 3))), + ("year:1001-W01:3", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 3))), + ("year:1001-W01-1:3", Period((DateUnit.YEAR, Instant((1000, 12, 29)), 3))), + ("month:1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("month:1001-01-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("week:1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("week:1001-W01-1", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("month:1001-01:1", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("month:1001-01:3", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 3))), + ("month:1001-01-01:3", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 3))), + ("week:1001-W01:1", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("week:1001-W01:3", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 3))), + ("week:1001-W01-1:3", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 3))), + ("day:1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("day:1001-01-01:3", Period((DateUnit.DAY, Instant((1001, 1, 1)), 3))), + ("weekday:1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), + ( + "weekday:1001-W01-1:3", + Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 3)), + ), + ], +) +def test_period(arg, expected) -> None: + assert periods.period(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (None, PeriodError), + (DateUnit.YEAR, PeriodError), + ("1", PeriodError), + ("999", PeriodError), + ("1000-0", PeriodError), + ("1000-13", PeriodError), + ("1000-W0", PeriodError), + ("1000-W54", PeriodError), + ("1000-0-0", PeriodError), + ("1000-1-0", PeriodError), + ("1000-2-31", PeriodError), + ("1000-W0-0", PeriodError), + ("1000-W1-0", PeriodError), + ("1000-W1-8", PeriodError), + ("a", PeriodError), + ("year", PeriodError), + ("1:1000", PeriodError), + ("a:1000", PeriodError), + ("month:1000", PeriodError), + ("week:1000", PeriodError), + ("day:1000-01", PeriodError), + ("weekday:1000-W1", PeriodError), + ("1000:a", PeriodError), + ("1000:1", PeriodError), + ("1000-01:1", PeriodError), + ("1000-01-01:1", PeriodError), + ("1000-W1:1", PeriodError), + ("1000-W1-1:1", PeriodError), + ("month:1000:1", PeriodError), + ("week:1000:1", PeriodError), + ("day:1000:1", PeriodError), + ("day:1000-01:1", PeriodError), + ("weekday:1000:1", PeriodError), + ("weekday:1000-W1:1", PeriodError), + ((), PeriodError), + ({}, PeriodError), + ("", PeriodError), + ((None,), PeriodError), + ((None, None), PeriodError), + ((None, None, None), PeriodError), + ((None, None, None, None), PeriodError), + ((Instant((1, 1, 1)),), PeriodError), + ((Period((DateUnit.DAY, Instant((1, 1, 1)), 365)),), PeriodError), + ((1,), PeriodError), + ((1, 1), PeriodError), + ((1, 1, 1), PeriodError), + ((-1,), PeriodError), + ((-1, -1), PeriodError), + ((-1, -1, -1), PeriodError), + (("-1",), PeriodError), + (("-1", "-1"), PeriodError), + (("-1", "-1", "-1"), PeriodError), + (("1-1",), PeriodError), + (("1-1-1",), PeriodError), + ], +) +def test_period_with_an_invalid_argument(arg, error) -> None: + with pytest.raises(error): + periods.period(arg) diff --git a/openfisca_core/periods/tests/test_instant.py b/openfisca_core/periods/tests/test_instant.py new file mode 100644 index 0000000000..e9c73ef6aa --- /dev/null +++ b/openfisca_core/periods/tests/test_instant.py @@ -0,0 +1,32 @@ +import pytest + +from openfisca_core.periods import DateUnit, Instant + + +@pytest.mark.parametrize( + ("instant", "offset", "unit", "expected"), + [ + (Instant((2020, 2, 29)), "first-of", DateUnit.YEAR, Instant((2020, 1, 1))), + (Instant((2020, 2, 29)), "first-of", DateUnit.MONTH, Instant((2020, 2, 1))), + (Instant((2020, 2, 29)), "first-of", DateUnit.WEEK, Instant((2020, 2, 24))), + (Instant((2020, 2, 29)), "first-of", DateUnit.DAY, None), + (Instant((2020, 2, 29)), "first-of", DateUnit.WEEKDAY, None), + (Instant((2020, 2, 29)), "last-of", DateUnit.YEAR, Instant((2020, 12, 31))), + (Instant((2020, 2, 29)), "last-of", DateUnit.MONTH, Instant((2020, 2, 29))), + (Instant((2020, 2, 29)), "last-of", DateUnit.WEEK, Instant((2020, 3, 1))), + (Instant((2020, 2, 29)), "last-of", DateUnit.DAY, None), + (Instant((2020, 2, 29)), "last-of", DateUnit.WEEKDAY, None), + (Instant((2020, 2, 29)), -3, DateUnit.YEAR, Instant((2017, 2, 28))), + (Instant((2020, 2, 29)), -3, DateUnit.MONTH, Instant((2019, 11, 29))), + (Instant((2020, 2, 29)), -3, DateUnit.WEEK, Instant((2020, 2, 8))), + (Instant((2020, 2, 29)), -3, DateUnit.DAY, Instant((2020, 2, 26))), + (Instant((2020, 2, 29)), -3, DateUnit.WEEKDAY, Instant((2020, 2, 26))), + (Instant((2020, 2, 29)), 3, DateUnit.YEAR, Instant((2023, 2, 28))), + (Instant((2020, 2, 29)), 3, DateUnit.MONTH, Instant((2020, 5, 29))), + (Instant((2020, 2, 29)), 3, DateUnit.WEEK, Instant((2020, 3, 21))), + (Instant((2020, 2, 29)), 3, DateUnit.DAY, Instant((2020, 3, 3))), + (Instant((2020, 2, 29)), 3, DateUnit.WEEKDAY, Instant((2020, 3, 3))), + ], +) +def test_offset(instant, offset, unit, expected) -> None: + assert instant.offset(offset, unit) == expected diff --git a/openfisca_core/periods/tests/test_parsers.py b/openfisca_core/periods/tests/test_parsers.py new file mode 100644 index 0000000000..c9131414b2 --- /dev/null +++ b/openfisca_core/periods/tests/test_parsers.py @@ -0,0 +1,129 @@ +import pytest + +from openfisca_core.periods import ( + DateUnit, + Instant, + InstantError, + ParserError, + Period, + PeriodError, + _parsers, +) + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + ("1001", Instant((1001, 1, 1))), + ("1001-01", Instant((1001, 1, 1))), + ("1001-12", Instant((1001, 12, 1))), + ("1001-01-01", Instant((1001, 1, 1))), + ("2028-02-29", Instant((2028, 2, 29))), + ("1001-W01", Instant((1000, 12, 29))), + ("1001-W52", Instant((1001, 12, 21))), + ("1001-W01-1", Instant((1000, 12, 29))), + ], +) +def test_parse_instant(arg, expected) -> None: + assert _parsers.parse_instant(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (None, InstantError), + ({}, InstantError), + ((), InstantError), + ([], InstantError), + (1, InstantError), + ("", InstantError), + ("à", InstantError), + ("1", InstantError), + ("-1", InstantError), + ("999", InstantError), + ("1000-0", InstantError), + ("1000-1", ParserError), + ("1000-1-1", InstantError), + ("1000-00", InstantError), + ("1000-13", InstantError), + ("1000-01-00", InstantError), + ("1000-01-99", InstantError), + ("2029-02-29", ParserError), + ("1000-W0", InstantError), + ("1000-W1", InstantError), + ("1000-W99", InstantError), + ("1000-W1-0", InstantError), + ("1000-W1-1", InstantError), + ("1000-W1-99", InstantError), + ("1000-W01-0", InstantError), + ("1000-W01-00", InstantError), + ], +) +def test_parse_instant_with_invalid_argument(arg, error) -> None: + with pytest.raises(error): + _parsers.parse_instant(arg) + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + ("1001", Period((DateUnit.YEAR, Instant((1001, 1, 1)), 1))), + ("1001-01", Period((DateUnit.MONTH, Instant((1001, 1, 1)), 1))), + ("1001-12", Period((DateUnit.MONTH, Instant((1001, 12, 1)), 1))), + ("1001-01-01", Period((DateUnit.DAY, Instant((1001, 1, 1)), 1))), + ("1001-W01", Period((DateUnit.WEEK, Instant((1000, 12, 29)), 1))), + ("1001-W52", Period((DateUnit.WEEK, Instant((1001, 12, 21)), 1))), + ("1001-W01-1", Period((DateUnit.WEEKDAY, Instant((1000, 12, 29)), 1))), + ], +) +def test_parse_period(arg, expected) -> None: + assert _parsers.parse_period(arg) == expected + + +@pytest.mark.parametrize( + ("arg", "error"), + [ + (None, PeriodError), + ({}, PeriodError), + ((), PeriodError), + ([], PeriodError), + (1, PeriodError), + ("", PeriodError), + ("à", PeriodError), + ("1", PeriodError), + ("-1", PeriodError), + ("999", PeriodError), + ("1000-0", PeriodError), + ("1000-1", ParserError), + ("1000-1-1", PeriodError), + ("1000-00", PeriodError), + ("1000-13", PeriodError), + ("1000-01-00", PeriodError), + ("1000-01-99", PeriodError), + ("1000-W0", PeriodError), + ("1000-W1", PeriodError), + ("1000-W99", PeriodError), + ("1000-W1-0", PeriodError), + ("1000-W1-1", PeriodError), + ("1000-W1-99", PeriodError), + ("1000-W01-0", PeriodError), + ("1000-W01-00", PeriodError), + ], +) +def test_parse_period_with_invalid_argument(arg, error) -> None: + with pytest.raises(error): + _parsers.parse_period(arg) + + +@pytest.mark.parametrize( + ("arg", "expected"), + [ + ("2022", DateUnit.YEAR), + ("2022-01", DateUnit.MONTH), + ("2022-01-01", DateUnit.DAY), + ("2022-W01", DateUnit.WEEK), + ("2022-W01-1", DateUnit.WEEKDAY), + ], +) +def test_parse_unit(arg, expected) -> None: + assert _parsers.parse_unit(arg) == expected diff --git a/openfisca_core/periods/tests/test_period.py b/openfisca_core/periods/tests/test_period.py new file mode 100644 index 0000000000..9e53bf7d12 --- /dev/null +++ b/openfisca_core/periods/tests/test_period.py @@ -0,0 +1,283 @@ +import pytest + +from openfisca_core.periods import DateUnit, Instant, Period + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 1, 1)), 1, "2022"), + (DateUnit.MONTH, Instant((2022, 1, 1)), 12, "2022"), + (DateUnit.YEAR, Instant((2022, 3, 1)), 1, "year:2022-03"), + (DateUnit.MONTH, Instant((2022, 3, 1)), 12, "year:2022-03"), + (DateUnit.YEAR, Instant((2022, 1, 1)), 3, "year:2022:3"), + (DateUnit.YEAR, Instant((2022, 1, 3)), 3, "year:2022:3"), + ], +) +def test_str_with_years(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.MONTH, Instant((2022, 1, 1)), 1, "2022-01"), + (DateUnit.MONTH, Instant((2022, 1, 1)), 3, "month:2022-01:3"), + (DateUnit.MONTH, Instant((2022, 3, 1)), 3, "month:2022-03:3"), + ], +) +def test_str_with_months(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.DAY, Instant((2022, 1, 1)), 1, "2022-01-01"), + (DateUnit.DAY, Instant((2022, 1, 1)), 3, "day:2022-01-01:3"), + (DateUnit.DAY, Instant((2022, 3, 1)), 3, "day:2022-03-01:3"), + ], +) +def test_str_with_days(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.WEEK, Instant((2022, 1, 1)), 1, "2021-W52"), + (DateUnit.WEEK, Instant((2022, 1, 1)), 3, "week:2021-W52:3"), + (DateUnit.WEEK, Instant((2022, 3, 1)), 1, "2022-W09"), + (DateUnit.WEEK, Instant((2022, 3, 1)), 3, "week:2022-W09:3"), + ], +) +def test_str_with_weeks(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.WEEKDAY, Instant((2022, 1, 1)), 1, "2021-W52-6"), + (DateUnit.WEEKDAY, Instant((2022, 1, 1)), 3, "weekday:2021-W52-6:3"), + (DateUnit.WEEKDAY, Instant((2022, 3, 1)), 1, "2022-W09-2"), + (DateUnit.WEEKDAY, Instant((2022, 3, 1)), 3, "weekday:2022-W09-2:3"), + ], +) +def test_str_with_weekdays(date_unit, instant, size, expected) -> None: + assert str(Period((date_unit, instant, size))) == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 1), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 2), + ], +) +def test_size_in_years(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_years == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 12), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 24), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 3), + ], +) +def test_size_in_months(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_months == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 365), + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 366), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 730), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 31), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 29), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 31 + 28 + 31), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 31 + 29 + 31), + (DateUnit.DAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.DAY, Instant((2022, 12, 31)), 3, 3), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 7), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 21), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 3, 3), + ], +) +def test_size_in_days(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_days == expected + assert period.size_in_days == period.days + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 52), + (DateUnit.YEAR, Instant((2020, 1, 1)), 5, 261), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 4), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 4), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 12), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 13), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 3), + ], +) +def test_size_in_weeks(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_weeks == expected + + +@pytest.mark.parametrize( + ("date_unit", "instant", "size", "expected"), + [ + (DateUnit.YEAR, Instant((2022, 12, 1)), 1, 364), + (DateUnit.YEAR, Instant((2020, 1, 1)), 1, 364), + (DateUnit.YEAR, Instant((2022, 1, 1)), 2, 728), + (DateUnit.MONTH, Instant((2022, 12, 1)), 1, 31), + (DateUnit.MONTH, Instant((2020, 2, 3)), 1, 29), + (DateUnit.MONTH, Instant((2022, 1, 3)), 3, 31 + 28 + 31), + (DateUnit.MONTH, Instant((2012, 1, 3)), 3, 31 + 29 + 31), + (DateUnit.DAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.DAY, Instant((2022, 12, 31)), 3, 3), + (DateUnit.WEEK, Instant((2022, 12, 31)), 1, 7), + (DateUnit.WEEK, Instant((2022, 12, 31)), 3, 21), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 1, 1), + (DateUnit.WEEKDAY, Instant((2022, 12, 31)), 3, 3), + ], +) +def test_size_in_weekdays(date_unit, instant, size, expected) -> None: + period = Period((date_unit, instant, size)) + assert period.size_in_weekdays == expected + + +@pytest.mark.parametrize( + ("period_unit", "sub_unit", "instant", "start", "cease", "count"), + [ + ( + DateUnit.YEAR, + DateUnit.YEAR, + Instant((2022, 12, 31)), + Instant((2022, 1, 1)), + Instant((2024, 1, 1)), + 3, + ), + ( + DateUnit.YEAR, + DateUnit.MONTH, + Instant((2022, 12, 31)), + Instant((2022, 12, 1)), + Instant((2025, 11, 1)), + 36, + ), + ( + DateUnit.YEAR, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2025, 12, 30)), + 1096, + ), + ( + DateUnit.YEAR, + DateUnit.WEEK, + Instant((2022, 12, 31)), + Instant((2022, 12, 26)), + Instant((2025, 12, 15)), + 156, + ), + ( + DateUnit.YEAR, + DateUnit.WEEKDAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2025, 12, 26)), + 1092, + ), + ( + DateUnit.MONTH, + DateUnit.MONTH, + Instant((2022, 12, 31)), + Instant((2022, 12, 1)), + Instant((2023, 2, 1)), + 3, + ), + ( + DateUnit.MONTH, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 3, 30)), + 90, + ), + ( + DateUnit.DAY, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 2)), + 3, + ), + ( + DateUnit.DAY, + DateUnit.WEEKDAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 2)), + 3, + ), + ( + DateUnit.WEEK, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 20)), + 21, + ), + ( + DateUnit.WEEK, + DateUnit.WEEK, + Instant((2022, 12, 31)), + Instant((2022, 12, 26)), + Instant((2023, 1, 9)), + 3, + ), + ( + DateUnit.WEEK, + DateUnit.WEEKDAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 20)), + 21, + ), + ( + DateUnit.WEEKDAY, + DateUnit.DAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 2)), + 3, + ), + ( + DateUnit.WEEKDAY, + DateUnit.WEEKDAY, + Instant((2022, 12, 31)), + Instant((2022, 12, 31)), + Instant((2023, 1, 2)), + 3, + ), + ], +) +def test_subperiods(period_unit, sub_unit, instant, start, cease, count) -> None: + period = Period((period_unit, instant, 3)) + subperiods = period.get_subperiods(sub_unit) + assert len(subperiods) == count + assert subperiods[0] == Period((sub_unit, start, 1)) + assert subperiods[-1] == Period((sub_unit, cease, 1)) diff --git a/openfisca_core/periods/types.py b/openfisca_core/periods/types.py new file mode 100644 index 0000000000..092509c621 --- /dev/null +++ b/openfisca_core/periods/types.py @@ -0,0 +1,183 @@ +# TODO(): Properly resolve metaclass types. +# https://github.com/python/mypy/issues/14033 + +from collections.abc import Sequence + +from openfisca_core.types import DateUnit, Instant, Period + +import re + +#: Matches "2015", "2015-01", "2015-01-01" but not "2015-13", "2015-12-32". +iso_format = re.compile(r"^\d{4}(-(?:0[1-9]|1[0-2])(-(?:0[1-9]|[12]\d|3[01]))?)?$") + +#: Matches "2015", "2015-W01", "2015-W53-1" but not "2015-W54", "2015-W10-8". +iso_calendar = re.compile(r"^\d{4}(-W(0[1-9]|[1-4][0-9]|5[0-3]))?(-[1-7])?$") + + +class _SeqIntMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return ( + bool(arg) + and isinstance(arg, Sequence) + and all(isinstance(item, int) for item in arg) + ) + + +class SeqInt(list[int], metaclass=_SeqIntMeta): # type: ignore[misc] + """A sequence of integers. + + Examples: + >>> isinstance([1, 2, 3], SeqInt) + True + + >>> isinstance((1, 2, 3), SeqInt) + True + + >>> isinstance({1, 2, 3}, SeqInt) + False + + >>> isinstance([1, 2, "3"], SeqInt) + False + + >>> isinstance(1, SeqInt) + False + + >>> isinstance([], SeqInt) + False + + """ + + +class _InstantStrMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return isinstance(arg, (ISOFormatStr, ISOCalendarStr)) + + +class InstantStr(str, metaclass=_InstantStrMeta): # type: ignore[misc] + """A string representing an instant in string format. + + Examples: + >>> isinstance("2015", InstantStr) + True + + >>> isinstance("2015-01", InstantStr) + True + + >>> isinstance("2015-W01", InstantStr) + True + + >>> isinstance("2015-W01-12", InstantStr) + False + + >>> isinstance("week:2015-W01:3", InstantStr) + False + + """ + + __slots__ = () + + +class _ISOFormatStrMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return isinstance(arg, str) and bool(iso_format.match(arg)) + + +class ISOFormatStr(str, metaclass=_ISOFormatStrMeta): # type: ignore[misc] + """A string representing an instant in ISO format. + + Examples: + >>> isinstance("2015", ISOFormatStr) + True + + >>> isinstance("2015-01", ISOFormatStr) + True + + >>> isinstance("2015-01-01", ISOFormatStr) + True + + >>> isinstance("2015-13", ISOFormatStr) + False + + >>> isinstance("2015-W01", ISOFormatStr) + False + + """ + + __slots__ = () + + +class _ISOCalendarStrMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return isinstance(arg, str) and bool(iso_calendar.match(arg)) + + +class ISOCalendarStr(str, metaclass=_ISOCalendarStrMeta): # type: ignore[misc] + """A string representing an instant in ISO calendar. + + Examples: + >>> isinstance("2015", ISOCalendarStr) + True + + >>> isinstance("2015-W01", ISOCalendarStr) + True + + >>> isinstance("2015-W11-7", ISOCalendarStr) + True + + >>> isinstance("2015-W010", ISOCalendarStr) + False + + >>> isinstance("2015-01", ISOCalendarStr) + False + + """ + + __slots__ = () + + +class _PeriodStrMeta(type): + def __instancecheck__(self, arg: object) -> bool: + return ( + isinstance(arg, str) + and ":" in arg + and isinstance(arg.split(":")[1], InstantStr) + ) + + +class PeriodStr(str, metaclass=_PeriodStrMeta): # type: ignore[misc] + """A string representing a period. + + Examples: + >>> isinstance("year", PeriodStr) + False + + >>> isinstance("2015", PeriodStr) + False + + >>> isinstance("year:2015", PeriodStr) + True + + >>> isinstance("month:2015-01", PeriodStr) + True + + >>> isinstance("weekday:2015-W01-1:365", PeriodStr) + True + + >>> isinstance("2015-W01:1", PeriodStr) + False + + """ + + __slots__ = () + + +__all__ = [ + "DateUnit", + "ISOCalendarStr", + "ISOFormatStr", + "Instant", + "InstantStr", + "Period", + "PeriodStr", + "SeqInt", +] diff --git a/openfisca_core/populations/__init__.py b/openfisca_core/populations/__init__.py index 7dedd71dc6..0047c528b6 100644 --- a/openfisca_core/populations/__init__.py +++ b/openfisca_core/populations/__init__.py @@ -21,18 +21,27 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.projectors import ( # noqa: F401 - Projector, +from openfisca_core.projectors import ( EntityToPersonProjector, FirstPersonToEntityProjector, + Projector, UniqueRoleToEntityProjector, - ) +) +from openfisca_core.projectors.helpers import get_projector_from_shortcut, projectable -from openfisca_core.projectors.helpers import ( # noqa: F401 - projectable, - get_projector_from_shortcut, - ) +from .config import ADD, DIVIDE +from .group_population import GroupPopulation +from .population import Population -from .config import ADD, DIVIDE # noqa: F401 -from .population import Population # noqa: F401 -from .group_population import GroupPopulation # noqa: F401 +__all__ = [ + "ADD", + "DIVIDE", + "EntityToPersonProjector", + "FirstPersonToEntityProjector", + "GroupPopulation", + "Population", + "Projector", + "UniqueRoleToEntityProjector", + "get_projector_from_shortcut", + "projectable", +] diff --git a/openfisca_core/populations/config.py b/openfisca_core/populations/config.py index 92a0b28865..b3e90875c8 100644 --- a/openfisca_core/populations/config.py +++ b/openfisca_core/populations/config.py @@ -1,2 +1,2 @@ -ADD = 'add' -DIVIDE = 'divide' +ADD = "add" +DIVIDE = "divide" diff --git a/openfisca_core/populations/group_population.py b/openfisca_core/populations/group_population.py index 81db310bfe..4e68762f19 100644 --- a/openfisca_core/populations/group_population.py +++ b/openfisca_core/populations/group_population.py @@ -2,14 +2,13 @@ import numpy -from openfisca_core import entities, projectors -from openfisca_core.entities import Role -from openfisca_core.indexed_enums import EnumArray -from openfisca_core.populations import Population +from openfisca_core import entities, indexed_enums, projectors + +from .population import Population class GroupPopulation(Population): - def __init__(self, entity, members): + def __init__(self, entity, members) -> None: super().__init__(entity) self.members = members self._members_entity_id = None @@ -20,7 +19,9 @@ def __init__(self, entity, members): def clone(self, simulation): result = GroupPopulation(self.entity, self.members) result.simulation = simulation - result._holders = {variable: holder.clone(self) for (variable, holder) in self._holders.items()} + result._holders = { + variable: holder.clone(self) for (variable, holder) in self._holders.items() + } result.count = self.count result.ids = self.ids result._members_entity_id = self._members_entity_id @@ -45,7 +46,7 @@ def members_position(self): return self._members_position @members_position.setter - def members_position(self, members_position): + def members_position(self, members_position) -> None: self._members_position = members_position @property @@ -53,7 +54,7 @@ def members_entity_id(self): return self._members_entity_id @members_entity_id.setter - def members_entity_id(self, members_entity_id): + def members_entity_id(self, members_entity_id) -> None: self._members_entity_id = members_entity_id @property @@ -64,14 +65,13 @@ def members_role(self): return self._members_role @members_role.setter - def members_role(self, members_role: typing.Iterable[Role]): + def members_role(self, members_role: typing.Iterable[entities.Role]) -> None: if members_role is not None: self._members_role = numpy.array(list(members_role)) @property def ordered_members_map(self): - """ - Mask to group the persons by entity + """Mask to group the persons by entity This function only caches the map value, to see what the map is used for, see value_nth_person method. """ if self._ordered_members_map is None: @@ -79,167 +79,193 @@ def ordered_members_map(self): return self._ordered_members_map def get_role(self, role_name): - return next((role for role in self.entity.flattened_roles if role.key == role_name), None) + return next( + (role for role in self.entity.flattened_roles if role.key == role_name), + None, + ) # Aggregation persons -> entity @projectors.projectable - def sum(self, array, role = None): - """ - Return the sum of ``array`` for the members of the entity. + def sum(self, array, role=None): + """Return the sum of ``array`` for the members of the entity. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.sum(salaries) + >>> array([3500]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.sum(salaries) - >>> array([3500]) """ - entities.check_role_validity(role) + self.entity.check_role_validity(role) self.members.check_array_compatible_with_entity(array) if role is not None: role_filter = self.members.has_role(role) return numpy.bincount( self.members_entity_id[role_filter], - weights = array[role_filter], - minlength = self.count) - else: - return numpy.bincount(self.members_entity_id, weights = array) + weights=array[role_filter], + minlength=self.count, + ) + return numpy.bincount(self.members_entity_id, weights=array) @projectors.projectable - def any(self, array, role = None): - """ - Return ``True`` if ``array`` is ``True`` for any members of the entity. + def any(self, array, role=None): + """Return ``True`` if ``array`` is ``True`` for any members of the entity. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.any(salaries >= 1800) + >>> array([True]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.any(salaries >= 1800) - >>> array([True]) """ - sum_in_entity = self.sum(array, role = role) - return (sum_in_entity > 0) + sum_in_entity = self.sum(array, role=role) + return sum_in_entity > 0 @projectors.projectable - def reduce(self, array, reducer, neutral_element, role = None): + def reduce(self, array, reducer, neutral_element, role=None): self.members.check_array_compatible_with_entity(array) - entities.check_role_validity(role) + self.entity.check_role_validity(role) position_in_entity = self.members_position role_filter = self.members.has_role(role) if role is not None else True filtered_array = numpy.where(role_filter, array, neutral_element) - result = self.filled_array(neutral_element) # Neutral value that will be returned if no one with the given role exists. + result = self.filled_array( + neutral_element, + ) # Neutral value that will be returned if no one with the given role exists. # We loop over the positions in the entity # Looping over the entities is tempting, but potentielly slow if there are a lot of entities biggest_entity_size = numpy.max(position_in_entity) + 1 for p in range(biggest_entity_size): - values = self.value_nth_person(p, filtered_array, default = neutral_element) + values = self.value_nth_person(p, filtered_array, default=neutral_element) result = reducer(result, values) return result @projectors.projectable - def all(self, array, role = None): - """ - Return ``True`` if ``array`` is ``True`` for all members of the entity. + def all(self, array, role=None): + """Return ``True`` if ``array`` is ``True`` for all members of the entity. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.all(salaries >= 1800) + >>> array([False]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.all(salaries >= 1800) - >>> array([False]) """ - return self.reduce(array, reducer = numpy.logical_and, neutral_element = True, role = role) + return self.reduce( + array, + reducer=numpy.logical_and, + neutral_element=True, + role=role, + ) @projectors.projectable - def max(self, array, role = None): - """ - Return the maximum value of ``array`` for the entity members. + def max(self, array, role=None): + """Return the maximum value of ``array`` for the entity members. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.max(salaries) + >>> array([2000]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.max(salaries) - >>> array([2000]) """ - return self.reduce(array, reducer = numpy.maximum, neutral_element = - numpy.infty, role = role) + return self.reduce( + array, + reducer=numpy.maximum, + neutral_element=-numpy.inf, + role=role, + ) @projectors.projectable - def min(self, array, role = None): - """ - Return the minimum value of ``array`` for the entity members. + def min(self, array, role=None): + """Return the minimum value of ``array`` for the entity members. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. - Example: + Example: + >>> salaries = household.members( + ... "salary", "2018-01" + ... ) # e.g. [2000, 1500, 0, 0, 0] + >>> household.min(salaries) + >>> array([0]) + >>> household.min( + ... salaries, role=Household.PARENT + ... ) # Assuming the 1st two persons are parents + >>> array([1500]) - >>> salaries = household.members('salary', '2018-01') # e.g. [2000, 1500, 0, 0, 0] - >>> household.min(salaries) - >>> array([0]) - >>> household.min(salaries, role = Household.PARENT) # Assuming the 1st two persons are parents - >>> array([1500]) """ - return self.reduce(array, reducer = numpy.minimum, neutral_element = numpy.infty, role = role) + return self.reduce( + array, + reducer=numpy.minimum, + neutral_element=numpy.inf, + role=role, + ) @projectors.projectable - def nb_persons(self, role = None): - """ - Returns the number of persons contained in the entity. + def nb_persons(self, role=None): + """Returns the number of persons contained in the entity. - If ``role`` is provided, only the entity member with the given role are taken into account. + If ``role`` is provided, only the entity member with the given role are taken into account. """ if role: if role.subroles: - role_condition = numpy.logical_or.reduce([self.members_role == subrole for subrole in role.subroles]) + role_condition = numpy.logical_or.reduce( + [self.members_role == subrole for subrole in role.subroles], + ) else: role_condition = self.members_role == role return self.sum(role_condition) - else: - return numpy.bincount(self.members_entity_id) + return numpy.bincount(self.members_entity_id) # Projection person -> entity @projectors.projectable - def value_from_person(self, array, role, default = 0): - """ - Get the value of ``array`` for the person with the unique role ``role``. + def value_from_person(self, array, role, default=0): + """Get the value of ``array`` for the person with the unique role ``role``. - ``array`` must have the dimension of the number of persons in the simulation + ``array`` must have the dimension of the number of persons in the simulation - If such a person does not exist, return ``default`` instead + If such a person does not exist, return ``default`` instead - The result is a vector which dimension is the number of entities + The result is a vector which dimension is the number of entities """ - entities.check_role_validity(role) + self.entity.check_role_validity(role) if role.max != 1: + msg = f"You can only use value_from_person with a role that is unique in {self.key}. Role {role.key} is not unique." raise Exception( - 'You can only use value_from_person with a role that is unique in {}. Role {} is not unique.' - .format(self.key, role.key) - ) + msg, + ) self.members.check_array_compatible_with_entity(array) members_map = self.ordered_members_map - result = self.filled_array(default, dtype = array.dtype) - if isinstance(array, EnumArray): - result = EnumArray(result, array.possible_values) + result = self.filled_array(default, dtype=array.dtype) + if isinstance(array, indexed_enums.EnumArray): + result = indexed_enums.EnumArray(result, array.possible_values) role_filter = self.members.has_role(role) entity_filter = self.any(role_filter) @@ -248,24 +274,28 @@ def value_from_person(self, array, role, default = 0): return result @projectors.projectable - def value_nth_person(self, n, array, default = 0): - """ - Get the value of array for the person whose position in the entity is n. + def value_nth_person(self, n, array, default=0): + """Get the value of array for the person whose position in the entity is n. - Note that this position is arbitrary, and that members are not sorted. + Note that this position is arbitrary, and that members are not sorted. - If the nth person does not exist, return ``default`` instead. + If the nth person does not exist, return ``default`` instead. - The result is a vector which dimension is the number of entities. + The result is a vector which dimension is the number of entities. """ self.members.check_array_compatible_with_entity(array) positions = self.members_position nb_persons_per_entity = self.nb_persons() members_map = self.ordered_members_map - result = self.filled_array(default, dtype = array.dtype) + result = self.filled_array(default, dtype=array.dtype) # For households that have at least n persons, set the result as the value of criteria for the person for which the position is n. # The map is needed b/c the order of the nth persons of each household in the persons vector is not necessarily the same than the household order. - result[nb_persons_per_entity > n] = array[members_map][positions[members_map] == n] + result[nb_persons_per_entity > n] = array[members_map][ + positions[members_map] == n + ] + + if isinstance(array, indexed_enums.EnumArray): + result = indexed_enums.EnumArray(result, array.possible_values) return result @@ -275,11 +305,10 @@ def value_from_first_person(self, array): # Projection entity -> person(s) - def project(self, array, role = None): + def project(self, array, role=None): self.check_array_compatible_with_entity(array) - entities.check_role_validity(role) + self.entity.check_role_validity(role) if role is None: return array[self.members_entity_id] - else: - role_condition = self.members.has_role(role) - return numpy.where(role_condition, array[self.members_entity_id], 0) + role_condition = self.members.has_role(role) + return numpy.where(role_condition, array[self.members_entity_id], 0) diff --git a/openfisca_core/populations/population.py b/openfisca_core/populations/population.py index 8b2471e0f3..06acc05d28 100644 --- a/openfisca_core/populations/population.py +++ b/openfisca_core/populations/population.py @@ -1,141 +1,225 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import NamedTuple +from typing_extensions import TypedDict + +from openfisca_core.types import Array, Period, Role, Simulation, SingleEntity + import traceback import numpy -from openfisca_core import entities, projectors -from openfisca_core.holders import Holder -from openfisca_core.populations import config -from openfisca_core.projectors import Projector +from openfisca_core import holders, periods, projectors + +from . import config class Population: - def __init__(self, entity): + simulation: Simulation | None + entity: SingleEntity + _holders: dict[str, holders.Holder] + count: int + ids: Array[str] + + def __init__(self, entity: SingleEntity) -> None: self.simulation = None self.entity = entity self._holders = {} self.count = 0 self.ids = [] - def clone(self, simulation): + def clone(self, simulation: Simulation) -> Population: result = Population(self.entity) result.simulation = simulation - result._holders = {variable: holder.clone(result) for (variable, holder) in self._holders.items()} + result._holders = { + variable: holder.clone(result) + for (variable, holder) in self._holders.items() + } result.count = self.count result.ids = self.ids return result - def empty_array(self): + def empty_array(self) -> Array[float]: return numpy.zeros(self.count) - def filled_array(self, value, dtype = None): + def filled_array( + self, + value: float | bool, + dtype: numpy.dtype | None = None, + ) -> Array[float] | Array[bool]: return numpy.full(self.count, value, dtype) - def __getattr__(self, attribute): + def __getattr__(self, attribute: str) -> projectors.Projector: + projector: projectors.Projector | None projector = projectors.get_projector_from_shortcut(self, attribute) - if not projector: - raise AttributeError("You tried to use the '{}' of '{}' but that is not a known attribute.".format(attribute, self.entity.key)) - return projector - def get_index(self, id): + if isinstance(projector, projectors.Projector): + return projector + + msg = f"You tried to use the '{attribute}' of '{self.entity.key}' but that is not a known attribute." + raise AttributeError( + msg, + ) + + def get_index(self, id: str) -> int: return self.ids.index(id) # Calculations - def check_array_compatible_with_entity(self, array): - if not self.count == array.size: - raise ValueError("Input {} is not a valid value for the entity {} (size = {} != {} = count)".format( - array, self.key, array.size, self.count)) - - def check_period_validity(self, variable_name, period): - if period is None: - stack = traceback.extract_stack() - filename, line_number, function_name, line_of_code = stack[-3] - raise ValueError(''' -You requested computation of variable "{}", but you did not specify on which period in "{}:{}": - {} + def check_array_compatible_with_entity( + self, + array: Array[float], + ) -> None: + if self.count == array.size: + return + + msg = f"Input {array} is not a valid value for the entity {self.entity.key} (size = {array.size} != {self.count} = count)" + raise ValueError( + msg, + ) + + def check_period_validity( + self, + variable_name: str, + period: int | str | Period | None, + ) -> None: + if isinstance(period, (int, str, periods.Period)): + return + + stack = traceback.extract_stack() + filename, line_number, function_name, line_of_code = stack[-3] + msg = f""" +You requested computation of variable "{variable_name}", but you did not specify on which period in "{filename}:{line_number}": + {line_of_code} When you request the computation of a variable within a formula, you must always specify the period as the second parameter. The convention is to call this parameter "period". For example: computed_salary = person('salary', period). See more information at . -'''.format(variable_name, filename, line_number, line_of_code)) - - def __call__(self, variable_name, period = None, options = None): - """ - Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists. +""" + raise ValueError( + msg, + ) + + def __call__( + self, + variable_name: str, + period: int | str | Period | None = None, + options: Sequence[str] | None = None, + ) -> Array[float] | None: + """Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists. - Example: + Example: + >>> person("salary", "2017-04") + >>> array([300.0]) - >>> person('salary', '2017-04') - >>> array([300.]) + :returns: A numpy array containing the result of the calculation - :returns: A numpy array containing the result of the calculation """ - self.entity.variables.isdefined().get(variable_name) - self.check_period_validity(variable_name, period) + if self.simulation is None: + return None + + calculate: Calculate = Calculate( + variable=variable_name, + period=periods.period(period), + option=options, + ) + + self.entity.check_variable_defined_for_entity(calculate.variable) + self.check_period_validity(calculate.variable, calculate.period) + + if not isinstance(calculate.option, Sequence): + return self.simulation.calculate( + calculate.variable, + calculate.period, + ) - if options is None: - options = [] + if config.ADD in calculate.option: + return self.simulation.calculate_add( + calculate.variable, + calculate.period, + ) + + if config.DIVIDE in calculate.option: + return self.simulation.calculate_divide( + calculate.variable, + calculate.period, + ) - if config.ADD in options and config.DIVIDE in options: - raise ValueError('Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {})'.format(variable_name).encode('utf-8')) - elif config.ADD in options: - return self.simulation.calculate_add(variable_name, period) - elif config.DIVIDE in options: - return self.simulation.calculate_divide(variable_name, period) - else: - return self.simulation.calculate(variable_name, period) + raise ValueError( + f"Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {variable_name})".encode(), + ) # Helpers - def get_holder(self, variable_name): + def get_holder(self, variable_name: str) -> holders.Holder: + self.entity.check_variable_defined_for_entity(variable_name) holder = self._holders.get(variable_name) - if holder: return holder - - variable = self.entity.variables.isdefined().get(variable_name) - self._holders[variable_name] = holder = Holder(variable, self) + variable = self.entity.get_variable(variable_name) + self._holders[variable_name] = holder = holders.Holder(variable, self) return holder - def get_memory_usage(self, variables = None): + def get_memory_usage( + self, + variables: Sequence[str] | None = None, + ) -> MemoryUsageByVariable: holders_memory_usage = { variable_name: holder.get_memory_usage() for variable_name, holder in self._holders.items() if variables is None or variable_name in variables - } + } total_memory_usage = sum( - holder_memory_usage['total_nb_bytes'] for holder_memory_usage in holders_memory_usage.values() - ) + holder_memory_usage["total_nb_bytes"] + for holder_memory_usage in holders_memory_usage.values() + ) - return dict( - total_nb_bytes = total_memory_usage, - by_variable = holders_memory_usage - ) + return MemoryUsageByVariable( + { + "total_nb_bytes": total_memory_usage, + "by_variable": holders_memory_usage, + }, + ) @projectors.projectable - def has_role(self, role): - """ - Check if a person has a given role within its :any:`GroupEntity` + def has_role(self, role: Role) -> Array[bool] | None: + """Check if a person has a given role within its `GroupEntity`. - Example: + Example: + >>> person.has_role(Household.CHILD) + >>> array([False]) - >>> person.has_role(Household.CHILD) - >>> array([False]) """ - entities.check_role_validity(role) + if self.simulation is None: + return None + + self.entity.check_role_validity(role) + group_population = self.simulation.get_population(role.entity.plural) + if role.subroles: - return numpy.logical_or.reduce([group_population.members_role == subrole for subrole in role.subroles]) - else: - return group_population.members_role == role + return numpy.logical_or.reduce( + [group_population.members_role == subrole for subrole in role.subroles], + ) + + return group_population.members_role == role @projectors.projectable - def value_from_partner(self, array, entity, role): + def value_from_partner( + self, + array: Array[float], + entity: projectors.Projector, + role: Role, + ) -> Array[float] | None: self.check_array_compatible_with_entity(array) - entities.check_role_validity(role) + self.entity.check_role_validity(role) - if not role.subroles or not len(role.subroles) == 2: - raise Exception('Projection to partner is only implemented for roles having exactly two subroles.') + if not role.subroles or len(role.subroles) != 2: + msg = "Projection to partner is only implemented for roles having exactly two subroles." + raise Exception( + msg, + ) [subrole_1, subrole_2] = role.subroles value_subrole_1 = entity.value_from_person(array, subrole_1) @@ -144,28 +228,39 @@ def value_from_partner(self, array, entity, role): return numpy.select( [self.has_role(subrole_1), self.has_role(subrole_2)], [value_subrole_2, value_subrole_1], - ) + ) @projectors.projectable - def get_rank(self, entity, criteria, condition = True): - """ - Get the rank of a person within an entity according to a criteria. + def get_rank( + self, + entity: Population, + criteria: Array[float], + condition: bool = True, + ) -> Array[int]: + """Get the rank of a person within an entity according to a criteria. The person with rank 0 has the minimum value of criteria. If condition is specified, then the persons who don't respect it are not taken into account and their rank is -1. Example: - - >>> age = person('age', period) # e.g [32, 34, 2, 8, 1] + >>> age = person("age", period) # e.g [32, 34, 2, 8, 1] >>> person.get_rank(household, age) >>> [3, 4, 0, 2, 1] - >>> is_child = person.has_role(Household.CHILD) # [False, False, True, True, True] - >>> person.get_rank(household, - age, condition = is_child) # Sort in reverse order so that the eldest child gets the rank 0. + >>> is_child = person.has_role( + ... Household.CHILD + ... ) # [False, False, True, True, True] + >>> person.get_rank( + ... household, -age, condition=is_child + ... ) # Sort in reverse order so that the eldest child gets the rank 0. >>> [-1, -1, 1, 0, 2] - """ + """ # If entity is for instance 'person.household', we get the reference entity 'household' behind the projector - entity = entity if not isinstance(entity, Projector) else entity.reference_entity + entity = ( + entity + if not isinstance(entity, projectors.Projector) + else entity.reference_entity + ) positions = entity.members_position biggest_entity_size = numpy.max(positions) + 1 @@ -173,10 +268,12 @@ def get_rank(self, entity, criteria, condition = True): ids = entity.members_entity_id # Matrix: the value in line i and column j is the value of criteria for the jth person of the ith entity - matrix = numpy.asarray([ - entity.value_nth_person(k, filtered_criteria, default = numpy.inf) - for k in range(biggest_entity_size) - ]).transpose() + matrix = numpy.asarray( + [ + entity.value_nth_person(k, filtered_criteria, default=numpy.inf) + for k in range(biggest_entity_size) + ], + ).transpose() # We double-argsort all lines of the matrix. # Double-argsorting gets the rank of each value once sorted @@ -188,3 +285,14 @@ def get_rank(self, entity, criteria, condition = True): # Return -1 for the persons who don't respect the condition return numpy.where(condition, result, -1) + + +class Calculate(NamedTuple): + variable: str + period: Period + option: Sequence[str] | None + + +class MemoryUsageByVariable(TypedDict, total=False): + by_variable: dict[str, holders.MemoryUsage] + total_nb_bytes: int diff --git a/openfisca_core/projectors/__init__.py b/openfisca_core/projectors/__init__.py index 02982bf982..28776e3cf9 100644 --- a/openfisca_core/projectors/__init__.py +++ b/openfisca_core/projectors/__init__.py @@ -21,8 +21,19 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .helpers import projectable, get_projector_from_shortcut # noqa: F401 -from .projector import Projector # noqa: F401 -from .entity_to_person_projector import EntityToPersonProjector # noqa: F401 -from .first_person_to_entity_projector import FirstPersonToEntityProjector # noqa: F401 -from .unique_role_to_entity_projector import UniqueRoleToEntityProjector # noqa: F401 +from . import typing +from .entity_to_person_projector import EntityToPersonProjector +from .first_person_to_entity_projector import FirstPersonToEntityProjector +from .helpers import get_projector_from_shortcut, projectable +from .projector import Projector +from .unique_role_to_entity_projector import UniqueRoleToEntityProjector + +__all__ = [ + "EntityToPersonProjector", + "FirstPersonToEntityProjector", + "get_projector_from_shortcut", + "projectable", + "Projector", + "UniqueRoleToEntityProjector", + "typing", +] diff --git a/openfisca_core/projectors/entity_to_person_projector.py b/openfisca_core/projectors/entity_to_person_projector.py index 3990233c70..392fda08a1 100644 --- a/openfisca_core/projectors/entity_to_person_projector.py +++ b/openfisca_core/projectors/entity_to_person_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class EntityToPersonProjector(Projector): """For instance person.family.""" - def __init__(self, entity, parent = None): + def __init__(self, entity, parent=None) -> None: self.reference_entity = entity self.parent = parent diff --git a/openfisca_core/projectors/first_person_to_entity_projector.py b/openfisca_core/projectors/first_person_to_entity_projector.py index 3912ccef1e..d986460cdc 100644 --- a/openfisca_core/projectors/first_person_to_entity_projector.py +++ b/openfisca_core/projectors/first_person_to_entity_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class FirstPersonToEntityProjector(Projector): """For instance famille.first_person.""" - def __init__(self, entity, parent = None): + def __init__(self, entity, parent=None) -> None: self.target_entity = entity self.reference_entity = entity.members self.parent = parent diff --git a/openfisca_core/projectors/helpers.py b/openfisca_core/projectors/helpers.py index 502eee1dfb..4c7712106a 100644 --- a/openfisca_core/projectors/helpers.py +++ b/openfisca_core/projectors/helpers.py @@ -1,23 +1,140 @@ -from openfisca_core import projectors +from __future__ import annotations + +from collections.abc import Mapping + +from openfisca_core.types import GroupEntity, Role, SingleEntity + +from openfisca_core import entities, projectors + +from .typing import GroupPopulation, Population def projectable(function): - """ - Decorator to indicate that when called on a projector, the outcome of the function must be projected. + """Decorator to indicate that when called on a projector, the outcome of the function must be projected. For instance person.household.sum(...) must be projected on person, while it would not make sense for person.household.get_holder. """ function.projectable = True return function -def get_projector_from_shortcut(population, shortcut, parent = None): - if population.entity.is_person: - if shortcut in population.simulation.populations: - entity_2 = population.simulation.populations[shortcut] - return projectors.EntityToPersonProjector(entity_2, parent) - else: - if shortcut == 'first_person': - return projectors.FirstPersonToEntityProjector(population, parent) - role = next((role for role in population.entity.flattened_roles if (role.max == 1) and (role.key == shortcut)), None) - if role: +def get_projector_from_shortcut( + population: Population | GroupPopulation, + shortcut: str, + parent: projectors.Projector | None = None, +) -> projectors.Projector | None: + """Get a projector from a shortcut. + + Projectors are used to project an invidividual Population's or a + collective GroupPopulation's on to other populations. + + The currently available cases are projecting: + - from an invidivual to a group + - from a group to an individual + - from a group to an individual with a unique role + + For example, if there are two entities, person (Entity) and household + (GroupEntity), on which calculations can be run (Population and + GroupPopulation respectively), and there is a Variable "rent" defined for + the household entity, then `person.household("rent")` will assign a rent to + every person within that household. + + Behind the scenes, this is done thanks to a Projector, and this function is + used to find the appropriate one for each case. In the above example, the + `shortcut` argument would be "household", and the `population` argument + whould be the Population linked to the "person" Entity in the context + of a specific Simulation and TaxBenefitSystem. + + Args: + population (Population | GroupPopulation): Where to project from. + shortcut (str): Where to project to. + parent: ??? + + Examples: + >>> from openfisca_core import ( + ... entities, + ... populations, + ... simulations, + ... taxbenefitsystems, + ... ) + + >>> entity = entities.Entity("person", "", "", "") + + >>> group_entity_1 = entities.GroupEntity("family", "", "", "", []) + + >>> roles = [ + ... {"key": "person", "max": 1}, + ... {"key": "animal", "subroles": ["cat", "dog"]}, + ... ] + + >>> group_entity_2 = entities.GroupEntity("household", "", "", "", roles) + + >>> population = populations.Population(entity) + + >>> group_population_1 = populations.GroupPopulation(group_entity_1, []) + + >>> group_population_2 = populations.GroupPopulation(group_entity_2, []) + + >>> populations = { + ... entity.key: population, + ... group_entity_1.key: group_population_1, + ... group_entity_2.key: group_population_2, + ... } + + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem( + ... [entity, group_entity_1, group_entity_2] + ... ) + + >>> simulation = simulations.Simulation(tax_benefit_system, populations) + + >>> get_projector_from_shortcut(population, "person") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(population, "family") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(population, "household") + <...EntityToPersonProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "first_person") + <...FirstPersonToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "person") + <...UniqueRoleToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "cat") + <...UniqueRoleToEntityProjector object at ...> + + >>> get_projector_from_shortcut(group_population_2, "dog") + <...UniqueRoleToEntityProjector object at ...> + + """ + entity: SingleEntity | GroupEntity = population.entity + + if isinstance(entity, entities.Entity): + populations: Mapping[ + str, + Population | GroupPopulation, + ] = population.simulation.populations + + if shortcut not in populations: + return None + + return projectors.EntityToPersonProjector(populations[shortcut], parent) + + if shortcut == "first_person": + return projectors.FirstPersonToEntityProjector(population, parent) + + if isinstance(entity, entities.GroupEntity): + role: Role | None = entities.find_role(entity.roles, shortcut, total=1) + + if role is not None: return projectors.UniqueRoleToEntityProjector(population, role, parent) + + if shortcut in entity.containing_entities: + projector: projectors.Projector = getattr( + projectors.FirstPersonToEntityProjector(population, parent), + shortcut, + ) + return projector + + return None diff --git a/openfisca_core/projectors/projector.py b/openfisca_core/projectors/projector.py index 41138813b5..37881201dc 100644 --- a/openfisca_core/projectors/projector.py +++ b/openfisca_core/projectors/projector.py @@ -6,12 +6,16 @@ class Projector: parent = None def __getattr__(self, attribute): - projector = helpers.get_projector_from_shortcut(self.reference_entity, attribute, parent = self) + projector = helpers.get_projector_from_shortcut( + self.reference_entity, + attribute, + parent=self, + ) if projector: return projector reference_attr = getattr(self.reference_entity, attribute) - if not hasattr(reference_attr, 'projectable'): + if not hasattr(reference_attr, "projectable"): return reference_attr def projector_function(*args, **kwargs): @@ -28,8 +32,7 @@ def transform_and_bubble_up(self, result): transformed_result = self.transform(result) if self.parent is None: return transformed_result - else: - return self.parent.transform_and_bubble_up(transformed_result) + return self.parent.transform_and_bubble_up(transformed_result) def transform(self, result): return NotImplementedError() diff --git a/openfisca_core/projectors/typing.py b/openfisca_core/projectors/typing.py new file mode 100644 index 0000000000..a49bc96621 --- /dev/null +++ b/openfisca_core/projectors/typing.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Protocol + +from openfisca_core.types import GroupEntity, SingleEntity + + +class Population(Protocol): + @property + def entity(self) -> SingleEntity: ... + + @property + def simulation(self) -> Simulation: ... + + +class GroupPopulation(Protocol): + @property + def entity(self) -> GroupEntity: ... + + @property + def simulation(self) -> Simulation: ... + + +class Simulation(Protocol): + @property + def populations(self) -> Mapping[str, Population | GroupPopulation]: ... diff --git a/openfisca_core/projectors/unique_role_to_entity_projector.py b/openfisca_core/projectors/unique_role_to_entity_projector.py index 25b3258dc3..c565484339 100644 --- a/openfisca_core/projectors/unique_role_to_entity_projector.py +++ b/openfisca_core/projectors/unique_role_to_entity_projector.py @@ -1,10 +1,10 @@ -from openfisca_core.projectors import Projector +from .projector import Projector class UniqueRoleToEntityProjector(Projector): - """ For instance famille.declarant_principal.""" + """For instance famille.declarant_principal.""" - def __init__(self, entity, role, parent = None): + def __init__(self, entity, role, parent=None) -> None: self.target_entity = entity self.reference_entity = entity.members self.parent = parent diff --git a/openfisca_core/reforms/reform.py b/openfisca_core/reforms/reform.py index 1ba0be30a8..76e7152334 100644 --- a/openfisca_core/reforms/reform.py +++ b/openfisca_core/reforms/reform.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy from openfisca_core.parameters import ParameterNode @@ -5,25 +7,22 @@ class Reform(TaxBenefitSystem): - """ - A modified TaxBenefitSystem - + """A modified TaxBenefitSystem. - All reforms must subclass `Reform` and implement a method `apply()`. + All reforms must subclass `Reform` and implement a method `apply()`. - In this method, the reform can add or replace variables and call :any:`modify_parameters` to modify the parameters of the legislation. - - Example: + In this method, the reform can add or replace variables and call `modify_parameters` to modify the parameters of the legislation. + Example: >>> from openfisca_core import reforms >>> from openfisca_core.parameters import load_parameter_file >>> >>> def modify_my_parameters(parameters): - >>> # Add new parameters + >>> # Add new parameters >>> new_parameters = load_parameter_file(name='reform_name', file_path='path_to_yaml_file.yaml') >>> parameters.add_child('reform_name', new_parameters) >>> - >>> # Update a value + >>> # Update a value >>> parameters.taxes.some_tax.some_param.update(period=some_period, value=1000.0) >>> >>> return parameters @@ -33,13 +32,13 @@ class Reform(TaxBenefitSystem): >>> self.add_variable(some_variable) >>> self.update_variable(some_other_variable) >>> self.modify_parameters(modifier_function = modify_my_parameters) + """ + name = None - def __init__(self, baseline): - """ - :param baseline: Baseline TaxBenefitSystem. - """ + def __init__(self, baseline) -> None: + """:param baseline: Baseline TaxBenefitSystem.""" super().__init__(baseline.entities) self.baseline = baseline self.parameters = baseline.parameters @@ -47,8 +46,9 @@ def __init__(self, baseline): self.variables = baseline.variables.copy() self.decomposition_file_path = baseline.decomposition_file_path self.key = self.__class__.__name__ - if not hasattr(self, 'apply'): - raise Exception("Reform {} must define an `apply` function".format(self.key)) + if not hasattr(self, "apply"): + msg = f"Reform {self.key} must define an `apply` function" + raise Exception(msg) self.apply() def __getattr__(self, attribute): @@ -57,27 +57,30 @@ def __getattr__(self, attribute): @property def full_key(self): key = self.key - assert key is not None, 'key was not set for reform {} (name: {!r})'.format(self, self.name) - if self.baseline is not None and hasattr(self.baseline, 'key'): + assert ( + key is not None + ), f"key was not set for reform {self} (name: {self.name!r})" + if self.baseline is not None and hasattr(self.baseline, "key"): baseline_full_key = self.baseline.full_key - key = '.'.join([baseline_full_key, key]) + key = f"{baseline_full_key}.{key}" return key def modify_parameters(self, modifier_function): - """ - Make modifications on the parameters of the legislation + """Make modifications on the parameters of the legislation. Call this function in `apply()` if the reform asks for legislation parameter modifications. - :param modifier_function: A function that takes an object of type :any:`ParameterNode` and should return an object of the same type. + Args: + modifier_function: A function that takes a :obj:`.ParameterNode` and should return an object of the same type. + """ baseline_parameters = self.baseline.parameters baseline_parameters_copy = copy.deepcopy(baseline_parameters) reform_parameters = modifier_function(baseline_parameters_copy) if not isinstance(reform_parameters, ParameterNode): return ValueError( - 'modifier_function {} in module {} must return a ParameterNode' - .format(modifier_function.__name__, modifier_function.__module__,) - ) + f"modifier_function {modifier_function.__name__} in module {modifier_function.__module__} must return a ParameterNode", + ) self.parameters = reform_parameters self._parameters_at_instant_cache = {} + return None diff --git a/openfisca_core/scripts/__init__.py b/openfisca_core/scripts/__init__.py index 9e0a3b67bc..e9080f2381 100644 --- a/openfisca_core/scripts/__init__.py +++ b/openfisca_core/scripts/__init__.py @@ -1,18 +1,33 @@ -# -*- coding: utf-8 -*- - -import traceback import importlib import logging import pkgutil +import traceback from os import linesep log = logging.getLogger(__name__) def add_tax_benefit_system_arguments(parser): - parser.add_argument('-c', '--country-package', action = 'store', help = 'country package to use. If not provided, an automatic detection will be attempted by scanning the python packages installed in your environment which name contains the word "openfisca".') - parser.add_argument('-e', '--extensions', action = 'store', help = 'extensions to load', nargs = '*') - parser.add_argument('-r', '--reforms', action = 'store', help = 'reforms to apply to the country package', nargs = '*') + parser.add_argument( + "-c", + "--country-package", + action="store", + help='country package to use. If not provided, an automatic detection will be attempted by scanning the python packages installed in your environment which name contains the word "openfisca".', + ) + parser.add_argument( + "-e", + "--extensions", + action="store", + help="extensions to load", + nargs="*", + ) + parser.add_argument( + "-r", + "--reforms", + action="store", + help="reforms to apply to the country package", + nargs="*", + ) return parser @@ -23,14 +38,21 @@ def build_tax_benefit_system(country_package_name, extensions, reforms): try: country_package = importlib.import_module(country_package_name) except ImportError: - message = linesep.join([traceback.format_exc(), - 'Could not import module `{}`.'.format(country_package_name), - 'Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.', - 'See more at .']) + message = linesep.join( + [ + traceback.format_exc(), + f"Could not import module `{country_package_name}`.", + "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", + "See more at .", + ], + ) raise ImportError(message) - if not hasattr(country_package, 'CountryTaxBenefitSystem'): - raise ImportError('`{}` does not seem to be a valid Openfisca country package.'.format(country_package_name)) + if not hasattr(country_package, "CountryTaxBenefitSystem"): + msg = f"`{country_package_name}` does not seem to be a valid Openfisca country package." + raise ImportError( + msg, + ) country_package = importlib.import_module(country_package_name) tax_benefit_system = country_package.CountryTaxBenefitSystem() @@ -54,19 +76,31 @@ def detect_country_package(): installed_country_packages = [] for module_description in pkgutil.iter_modules(): module_name = module_description[1] - if 'openfisca' in module_name.lower(): + if "openfisca" in module_name.lower(): try: module = importlib.import_module(module_name) except ImportError: - message = linesep.join([traceback.format_exc(), - 'Could not import module `{}`.'.format(module_name), - 'Look at the stack trace above to determine the error that stopped installed modules detection.']) + message = linesep.join( + [ + traceback.format_exc(), + f"Could not import module `{module_name}`.", + "Look at the stack trace above to determine the error that stopped installed modules detection.", + ], + ) raise ImportError(message) - if hasattr(module, 'CountryTaxBenefitSystem'): + if hasattr(module, "CountryTaxBenefitSystem"): installed_country_packages.append(module_name) if len(installed_country_packages) == 0: - raise ImportError('No country package has been detected on your environment. If your country package is installed but not detected, please use the --country-package option.') + msg = "No country package has been detected on your environment. If your country package is installed but not detected, please use the --country-package option." + raise ImportError( + msg, + ) if len(installed_country_packages) > 1: - log.warning('Several country packages detected : `{}`. Using `{}` by default. To use another package, please use the --country-package option.'.format(', '.join(installed_country_packages), installed_country_packages[0])) + log.warning( + "Several country packages detected : `{}`. Using `{}` by default. To use another package, please use the --country-package option.".format( + ", ".join(installed_country_packages), + installed_country_packages[0], + ), + ) return installed_country_packages[0] diff --git a/openfisca_core/scripts/find_placeholders.py b/openfisca_core/scripts/find_placeholders.py index b14fd5fea9..b7b5a81969 100644 --- a/openfisca_core/scripts/find_placeholders.py +++ b/openfisca_core/scripts/find_placeholders.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- # flake8: noqa T001 -import os import fnmatch +import os import sys from bs4 import BeautifulSoup @@ -10,42 +9,37 @@ def find_param_files(input_dir): param_files = [] - for root, dirnames, filenames in os.walk(input_dir): - for filename in fnmatch.filter(filenames, '*.xml'): + for root, _dirnames, filenames in os.walk(input_dir): + for filename in fnmatch.filter(filenames, "*.xml"): param_files.append(os.path.join(root, filename)) return param_files def find_placeholders(filename_input): - with open(filename_input, 'r') as f: + with open(filename_input) as f: xml_content = f.read() xml_parsed = BeautifulSoup(xml_content, "lxml-xml") - placeholders = xml_parsed.find_all('PLACEHOLDER') + placeholders = xml_parsed.find_all("PLACEHOLDER") output_list = [] for placeholder in placeholders: parent_list = list(placeholder.parents)[:-1] - path = '.'.join([p.attrs['code'] for p in parent_list if 'code' in p.attrs][::-1]) + path = ".".join( + [p.attrs["code"] for p in parent_list if "code" in p.attrs][::-1], + ) - deb = placeholder.attrs['deb'] + deb = placeholder.attrs["deb"] output_list.append((deb, path)) - output_list = sorted(output_list, key = lambda x: x[0]) - - return output_list + return sorted(output_list, key=lambda x: x[0]) if __name__ == "__main__": - print('''find_placeholders.py : Find nodes PLACEHOLDER in xml parameter files -Usage : - python find_placeholders /dir/to/search -''') - - assert(len(sys.argv) == 2) + assert len(sys.argv) == 2 input_dir = sys.argv[1] param_files = find_param_files(input_dir) @@ -53,9 +47,5 @@ def find_placeholders(filename_input): for filename_input in param_files: output_list = find_placeholders(filename_input) - print('File {}'.format(filename_input)) - - for deb, path in output_list: - print('{} {}'.format(deb, path)) - - print('\n') + for _deb, _path in output_list: + pass diff --git a/openfisca_core/scripts/measure_numpy_condition_notations.py b/openfisca_core/scripts/measure_numpy_condition_notations.py index 3132205a1c..65e48f6e2c 100755 --- a/openfisca_core/scripts/measure_numpy_condition_notations.py +++ b/openfisca_core/scripts/measure_numpy_condition_notations.py @@ -1,33 +1,30 @@ #! /usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa T001 -""" -Measure and compare different vectorial condition notations: +"""Measure and compare different vectorial condition notations: - using multiplication notation: (choice == 1) * choice_1_value + (choice == 2) * choice_2_value -- using np.select: the same than multiplication but more idiomatic like a "switch" control-flow statement -- using np.fromiter: iterates in Python over the array and calculates lazily only the required values +- using numpy.select: the same than multiplication but more idiomatic like a "switch" control-flow statement +- using numpy.fromiter: iterates in Python over the array and calculates lazily only the required values. The aim of this script is to compare the time taken by the calculation of the values """ -from contextlib import contextmanager + import argparse import sys import time +from contextlib import contextmanager -import numpy as np - +import numpy args = None @contextmanager def measure_time(title): - t1 = time.time() + time.time() yield - t2 = time.time() - print('{}\t: {:.8f} seconds elapsed'.format(title, t2 - t1)) + time.time() def switch_fromiter(conditions, function_by_condition, dtype): @@ -39,34 +36,28 @@ def get_or_store_value(condition): value_by_condition[condition] = value return value_by_condition[condition] - return np.fromiter( - ( - get_or_store_value(condition) - for condition in conditions - ), + return numpy.fromiter( + (get_or_store_value(condition) for condition in conditions), dtype, - ) + ) def switch_select(conditions, value_by_condition): - condlist = [ - conditions == condition - for condition in value_by_condition.keys() - ] - return np.select(condlist, value_by_condition.values()) + condlist = [conditions == condition for condition in value_by_condition] + return numpy.select(condlist, value_by_condition.values()) -def calculate_choice_1_value(): +def calculate_choice_1_value() -> int: time.sleep(args.calculate_time) return 80 -def calculate_choice_2_value(): +def calculate_choice_2_value() -> int: time.sleep(args.calculate_time) return 90 -def calculate_choice_3_value(): +def calculate_choice_3_value() -> int: time.sleep(args.calculate_time) return 95 @@ -75,61 +66,70 @@ def test_multiplication(choice): choice_1_value = calculate_choice_1_value() choice_2_value = calculate_choice_2_value() choice_3_value = calculate_choice_3_value() - result = (choice == 1) * choice_1_value + (choice == 2) * choice_2_value + (choice == 3) * choice_3_value - return result + return ( + (choice == 1) * choice_1_value + + (choice == 2) * choice_2_value + + (choice == 3) * choice_3_value + ) def test_switch_fromiter(choice): - result = switch_fromiter( + return switch_fromiter( choice, { 1: calculate_choice_1_value, 2: calculate_choice_2_value, 3: calculate_choice_3_value, - }, - dtype = np.int, - ) - return result + }, + dtype=int, + ) def test_switch_select(choice): choice_1_value = calculate_choice_1_value() choice_2_value = calculate_choice_2_value() choice_3_value = calculate_choice_2_value() - result = switch_select( + return switch_select( choice, { 1: choice_1_value, 2: choice_2_value, 3: choice_3_value, - }, - ) - return result + }, + ) -def test_all_notations(): +def test_all_notations() -> None: # choice is an array with 1 and 2 items like [2, 1, ..., 1, 2] - choice = np.random.randint(2, size = args.array_length) + 1 + choice = numpy.random.randint(2, size=args.array_length) + 1 - with measure_time('multiplication'): + with measure_time("multiplication"): test_multiplication(choice) - with measure_time('switch_select'): + with measure_time("switch_select"): test_switch_select(choice) - with measure_time('switch_fromiter'): + with measure_time("switch_fromiter"): test_switch_fromiter(choice) -def main(): - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument('--array-length', default = 1000, type = int, help = "length of the array") - parser.add_argument('--calculate-time', default = 0.1, type = float, - help = "time taken by the calculation in seconds") +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--array-length", + default=1000, + type=int, + help="length of the array", + ) + parser.add_argument( + "--calculate-time", + default=0.1, + type=float, + help="time taken by the calculation in seconds", + ) global args args = parser.parse_args() - print(args) test_all_notations() diff --git a/openfisca_core/scripts/measure_performances.py b/openfisca_core/scripts/measure_performances.py index 1d84ddd585..48b99c93f8 100644 --- a/openfisca_core/scripts/measure_performances.py +++ b/openfisca_core/scripts/measure_performances.py @@ -1,35 +1,32 @@ #! /usr/bin/env python -# -*- coding: utf-8 -*- # flake8: noqa T001 """Measure performances of a basic tax-benefit system to compare to other OpenFisca implementations.""" + import argparse import logging import sys import time -import numpy as np +import numpy from numpy.core.defchararray import startswith from openfisca_core import periods, simulations -from openfisca_core.periods import ETERNITY from openfisca_core.entities import build_entity -from openfisca_core.variables import Variable +from openfisca_core.periods import DateUnit from openfisca_core.taxbenefitsystems import TaxBenefitSystem from openfisca_core.tools import assert_near - +from openfisca_core.variables import Variable args = None def timeit(method): def timed(*args, **kwargs): - start_time = time.time() - result = method(*args, **kwargs) + time.time() + return method(*args, **kwargs) # print '%r (%r, %r) %2.9f s' % (method.__name__, args, kw, time.time() - start_time) - print('{:2.6f} s'.format(time.time() - start_time)) - return result return timed @@ -37,31 +34,31 @@ def timed(*args, **kwargs): # Entities Famille = build_entity( - key = "famille", - plural = "familles", - label = 'Famille', - roles = [ + key="famille", + plural="familles", + label="Famille", + roles=[ { - 'key': 'parent', - 'plural': 'parents', - 'label': 'Parents', - 'subroles': ['demandeur', 'conjoint'] - }, + "key": "parent", + "plural": "parents", + "label": "Parents", + "subroles": ["demandeur", "conjoint"], + }, { - 'key': 'enfant', - 'plural': 'enfants', - 'label': 'Enfants', - } - ] - ) + "key": "enfant", + "plural": "enfants", + "label": "Enfants", + }, + ], +) Individu = build_entity( - key = "individu", - plural = "individus", - label = 'Individu', - is_person = True, - ) + key="individu", + plural="individus", + label="Individu", + is_person=True, +) # Input variables @@ -73,16 +70,16 @@ class age_en_mois(Variable): class birth(Variable): - value_type = 'Date' + value_type = "Date" entity = Individu label = "Date de naissance" class city_code(Variable): - value_type = 'FixedStr' + value_type = "FixedStr" max_length = 5 entity = Famille - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY label = """Code INSEE "city_code" de la commune de résidence de la famille""" @@ -94,30 +91,33 @@ class salaire_brut(Variable): # Calculated variables + class age(Variable): value_type = int entity = Individu label = "Âge (en nombre d'années)" def formula(self, simulation, period): - birth = simulation.get_array('birth', period) + birth = simulation.get_array("birth", period) if birth is None: - age_en_mois = simulation.get_array('age_en_mois', period) + age_en_mois = simulation.get_array("age_en_mois", period) if age_en_mois is not None: return age_en_mois // 12 - birth = simulation.calculate('birth', period) - return (np.datetime64(period.date) - birth).astype('timedelta64[Y]') + birth = simulation.calculate("birth", period) + return (numpy.datetime64(period.date) - birth).astype("timedelta64[Y]") class dom_tom(Variable): - value_type = 'Bool' + value_type = "Bool" entity = Famille label = "La famille habite-t-elle les DOM-TOM ?" def formula(self, simulation, period): - period = period.start.period('year').offset('first-of') - city_code = simulation.calculate('city_code', period) - return np.logical_or(startswith(city_code, '97'), startswith(city_code, '98')) + period = period.start.period(DateUnit.YEAR).offset("first-of") + city_code = simulation.calculate("city_code", period) + return numpy.logical_or( + startswith(city_code, "97"), startswith(city_code, "98") + ) class revenu_disponible(Variable): @@ -126,9 +126,9 @@ class revenu_disponible(Variable): label = "Revenu disponible de l'individu" def formula(self, simulation, period): - period = period.start.period('year').offset('first-of') - rsa = simulation.calculate('rsa', period) - salaire_imposable = simulation.calculate('salaire_imposable', period) + period = period.start.period(DateUnit.YEAR).offset("first-of") + rsa = simulation.calculate("rsa", period) + salaire_imposable = simulation.calculate("salaire_imposable", period) return rsa + salaire_imposable * 0.7 @@ -138,18 +138,18 @@ class rsa(Variable): label = "RSA" def formula_2010_01_01(self, simulation, period): - period = period.start.period('month').offset('first-of') - salaire_imposable = simulation.calculate('salaire_imposable', period) + period = period.start.period(DateUnit.MONTH).offset("first-of") + salaire_imposable = simulation.calculate("salaire_imposable", period) return (salaire_imposable < 500) * 100.0 def formula_2011_01_01(self, simulation, period): - period = period.start.period('month').offset('first-of') - salaire_imposable = simulation.calculate('salaire_imposable', period) + period = period.start.period(DateUnit.MONTH).offset("first-of") + salaire_imposable = simulation.calculate("salaire_imposable", period) return (salaire_imposable < 500) * 200.0 def formula_2013_01_01(self, simulation, period): - period = period.start.period('month').offset('first-of') - salaire_imposable = simulation.calculate('salaire_imposable', period) + period = period.start.period(DateUnit.MONTH).offset("first-of") + salaire_imposable = simulation.calculate("salaire_imposable", period) return (salaire_imposable < 500) * 300 @@ -158,10 +158,10 @@ class salaire_imposable(Variable): entity = Individu label = "Salaire imposable" - def formula(individu, period): - period = period.start.period('year').offset('first-of') - dom_tom = individu.famille('dom_tom', period) - salaire_net = individu('salaire_net', period) + def formula(self, period): + period = period.start.period(DateUnit.YEAR).offset("first-of") + dom_tom = self.famille("dom_tom", period) + salaire_net = self("salaire_net", period) return salaire_net * 0.9 - 100 * dom_tom @@ -171,8 +171,8 @@ class salaire_net(Variable): label = "Salaire net" def formula(self, simulation, period): - period = period.start.period('year').offset('first-of') - salaire_brut = simulation.calculate('salaire_brut', period) + period = period.start.period(DateUnit.YEAR).offset("first-of") + salaire_brut = simulation.calculate("salaire_brut", period) return salaire_brut * 0.8 @@ -180,13 +180,26 @@ def formula(self, simulation, period): tax_benefit_system = TaxBenefitSystem([Famille, Individu]) -tax_benefit_system.add_variables(age_en_mois, birth, city_code, salaire_brut, age, - dom_tom, revenu_disponible, rsa, salaire_imposable, salaire_net) +tax_benefit_system.add_variables( + age_en_mois, + birth, + city_code, + salaire_brut, + age, + dom_tom, + revenu_disponible, + rsa, + salaire_imposable, + salaire_net, +) @timeit -def check_revenu_disponible(year, city_code, expected_revenu_disponible): - simulation = simulations.Simulation(period = periods.period(year), tax_benefit_system = tax_benefit_system) +def check_revenu_disponible(year, city_code, expected_revenu_disponible) -> None: + simulation = simulations.Simulation( + period=periods.period(year), + tax_benefit_system=tax_benefit_system, + ) famille = simulation.populations["famille"] famille.count = 3 famille.roles_count = 2 @@ -194,31 +207,84 @@ def check_revenu_disponible(year, city_code, expected_revenu_disponible): individu = simulation.populations["individu"] individu.count = 6 individu.step_size = 2 - simulation.get_or_new_holder("city_code").array = np.array([city_code, city_code, city_code]) - famille.members_entity_id = np.array([0, 0, 1, 1, 2, 2]) - simulation.get_or_new_holder("salaire_brut").array = np.array([0.0, 0.0, 50000.0, 0.0, 100000.0, 0.0]) - revenu_disponible = simulation.calculate('revenu_disponible') - assert_near(revenu_disponible, expected_revenu_disponible, absolute_error_margin = 0.005) + simulation.get_or_new_holder("city_code").array = numpy.array( + [city_code, city_code, city_code], + ) + famille.members_entity_id = numpy.array([0, 0, 1, 1, 2, 2]) + simulation.get_or_new_holder("salaire_brut").array = numpy.array( + [0.0, 0.0, 50000.0, 0.0, 100000.0, 0.0], + ) + revenu_disponible = simulation.calculate("revenu_disponible") + assert_near( + revenu_disponible, + expected_revenu_disponible, + absolute_error_margin=0.005, + ) -def main(): - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument('-v', '--verbose', action = 'store_true', default = False, help = "increase output verbosity") +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="increase output verbosity", + ) global args args = parser.parse_args() - logging.basicConfig(level = logging.DEBUG if args.verbose else logging.WARNING, stream = sys.stdout) - - check_revenu_disponible(2009, '75101', np.array([0, 0, 25200, 0, 50400, 0])) - check_revenu_disponible(2010, '75101', np.array([1200, 1200, 25200, 1200, 50400, 1200])) - check_revenu_disponible(2011, '75101', np.array([2400, 2400, 25200, 2400, 50400, 2400])) - check_revenu_disponible(2012, '75101', np.array([2400, 2400, 25200, 2400, 50400, 2400])) - check_revenu_disponible(2013, '75101', np.array([3600, 3600, 25200, 3600, 50400, 3600])) - - check_revenu_disponible(2009, '97123', np.array([-70.0, -70.0, 25130.0, -70.0, 50330.0, -70.0])) - check_revenu_disponible(2010, '97123', np.array([1130.0, 1130.0, 25130.0, 1130.0, 50330.0, 1130.0])) - check_revenu_disponible(2011, '98456', np.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0])) - check_revenu_disponible(2012, '98456', np.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0])) - check_revenu_disponible(2013, '98456', np.array([3530.0, 3530.0, 25130.0, 3530.0, 50330.0, 3530.0])) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.WARNING, + stream=sys.stdout, + ) + + check_revenu_disponible(2009, "75101", numpy.array([0, 0, 25200, 0, 50400, 0])) + check_revenu_disponible( + 2010, + "75101", + numpy.array([1200, 1200, 25200, 1200, 50400, 1200]), + ) + check_revenu_disponible( + 2011, + "75101", + numpy.array([2400, 2400, 25200, 2400, 50400, 2400]), + ) + check_revenu_disponible( + 2012, + "75101", + numpy.array([2400, 2400, 25200, 2400, 50400, 2400]), + ) + check_revenu_disponible( + 2013, + "75101", + numpy.array([3600, 3600, 25200, 3600, 50400, 3600]), + ) + + check_revenu_disponible( + 2009, + "97123", + numpy.array([-70.0, -70.0, 25130.0, -70.0, 50330.0, -70.0]), + ) + check_revenu_disponible( + 2010, + "97123", + numpy.array([1130.0, 1130.0, 25130.0, 1130.0, 50330.0, 1130.0]), + ) + check_revenu_disponible( + 2011, + "98456", + numpy.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]), + ) + check_revenu_disponible( + 2012, + "98456", + numpy.array([2330.0, 2330.0, 25130.0, 2330.0, 50330.0, 2330.0]), + ) + check_revenu_disponible( + 2013, + "98456", + numpy.array([3530.0, 3530.0, 25130.0, 3530.0, 50330.0, 3530.0]), + ) if __name__ == "__main__": diff --git a/openfisca_core/scripts/measure_performances_fancy_indexing.py b/openfisca_core/scripts/measure_performances_fancy_indexing.py index 894250ef54..7c261e2fe3 100644 --- a/openfisca_core/scripts/measure_performances_fancy_indexing.py +++ b/openfisca_core/scripts/measure_performances_fancy_indexing.py @@ -1,24 +1,26 @@ # flake8: noqa T001 -import numpy as np import timeit -from openfisca_france import CountryTaxBenefitSystem -from openfisca_core.model_api import * # noqa analysis:ignore +import numpy +from openfisca_france import CountryTaxBenefitSystem tbs = CountryTaxBenefitSystem() N = 200000 -al_plaf_acc = tbs.get_parameters_at_instant('2015-01-01').prestations.al_plaf_acc -zone_apl = np.random.choice([1, 2, 3], N) -al_nb_pac = np.random.choice(6, N) -couple = np.random.choice([True, False], N) -formatted_zone = concat('plafond_pour_accession_a_la_propriete_zone_', zone_apl) # zone_apl returns 1, 2 or 3 but the parameters have a long name +al_plaf_acc = tbs.get_parameters_at_instant("2015-01-01").prestations.al_plaf_acc +zone_apl = numpy.random.choice([1, 2, 3], N) +al_nb_pac = numpy.random.choice(6, N) +couple = numpy.random.choice([True, False], N) +formatted_zone = concat( + "plafond_pour_accession_a_la_propriete_zone_", + zone_apl, +) # zone_apl returns 1, 2 or 3 but the parameters have a long name def formula_with(): plafonds = al_plaf_acc[formatted_zone] - result = ( + return ( plafonds.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + plafonds.menage_seul * couple * (al_nb_pac == 0) + plafonds.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) @@ -26,10 +28,10 @@ def formula_with(): + plafonds.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) + plafonds.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) + plafonds.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) - + plafonds.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) - ) - - return result + + plafonds.menage_ou_isole_par_enfant_en_plus + * (al_nb_pac > 5) + * (al_nb_pac - 5) + ) def formula_without(): @@ -37,41 +39,51 @@ def formula_without(): z2 = al_plaf_acc.plafond_pour_accession_a_la_propriete_zone_2 z3 = al_plaf_acc.plafond_pour_accession_a_la_propriete_zone_3 - return (zone_apl == 1) * ( - z1.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) - + z1.menage_seul * couple * (al_nb_pac == 0) - + z1.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) - + z1.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) - + z1.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) - + z1.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) - + z1.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) - + z1.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) - ) + (zone_apl == 2) * ( - z2.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) - + z2.menage_seul * couple * (al_nb_pac == 0) - + z2.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) - + z2.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) - + z2.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) - + z2.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) - + z2.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) - + z2.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) - ) + (zone_apl == 3) * ( - z3.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) - + z3.menage_seul * couple * (al_nb_pac == 0) - + z3.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) - + z3.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) - + z3.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) - + z3.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) - + z3.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) - + z3.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) + return ( + (zone_apl == 1) + * ( + z1.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + + z1.menage_seul * couple * (al_nb_pac == 0) + + z1.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) + + z1.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) + + z1.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) + + z1.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) + + z1.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) + + z1.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) ) + + (zone_apl == 2) + * ( + z2.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + + z2.menage_seul * couple * (al_nb_pac == 0) + + z2.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) + + z2.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) + + z2.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) + + z2.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) + + z2.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) + + z2.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) + ) + + (zone_apl == 3) + * ( + z3.personne_isolee_sans_enfant * not_(couple) * (al_nb_pac == 0) + + z3.menage_seul * couple * (al_nb_pac == 0) + + z3.menage_ou_isole_avec_1_enfant * (al_nb_pac == 1) + + z3.menage_ou_isole_avec_2_enfants * (al_nb_pac == 2) + + z3.menage_ou_isole_avec_3_enfants * (al_nb_pac == 3) + + z3.menage_ou_isole_avec_4_enfants * (al_nb_pac == 4) + + z3.menage_ou_isole_avec_5_enfants * (al_nb_pac >= 5) + + z3.menage_ou_isole_par_enfant_en_plus * (al_nb_pac > 5) * (al_nb_pac - 5) + ) + ) -if __name__ == '__main__': - - time_with = timeit.timeit('formula_with()', setup = "from __main__ import formula_with", number = 50) - time_without = timeit.timeit('formula_without()', setup = "from __main__ import formula_without", number = 50) - - print("Computing with dynamic legislation computing took {}".format(time_with)) - print("Computing without dynamic legislation computing took {}".format(time_without)) - print("Ratio: {}".format(time_with / time_without)) +if __name__ == "__main__": + time_with = timeit.timeit( + "formula_with()", + setup="from __main__ import formula_with", + number=50, + ) + time_without = timeit.timeit( + "formula_without()", + setup="from __main__ import formula_without", + number=50, + ) diff --git a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py index 6e8f672988..38538d644a 100644 --- a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py +++ b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_country_template.py @@ -1,34 +1,30 @@ -# -*- coding: utf-8 -*- - -''' xml_to_yaml_country_template.py : Parse XML parameter files for Country-Template and convert them to YAML files. Comments are NOT transformed. +"""xml_to_yaml_country_template.py : Parse XML parameter files for Country-Template and convert them to YAML files. Comments are NOT transformed. Usage : `python xml_to_yaml_country_template.py output_dir` or just (output is written in a directory called `yaml_parameters`): `python xml_to_yaml_country_template.py` -''' -import sys +""" + import os +import sys + +from openfisca_country_template import COUNTRY_DIR, CountryTaxBenefitSystem -from openfisca_country_template import CountryTaxBenefitSystem, COUNTRY_DIR from . import xml_to_yaml tax_benefit_system = CountryTaxBenefitSystem() -if len(sys.argv) > 1: - target_path = sys.argv[1] -else: - target_path = 'yaml_parameters' +target_path = sys.argv[1] if len(sys.argv) > 1 else "yaml_parameters" -param_dir = os.path.join(COUNTRY_DIR, 'parameters') +param_dir = os.path.join(COUNTRY_DIR, "parameters") param_files = [ - 'benefits.xml', - 'general.xml', - 'taxes.xml', - ] + "benefits.xml", + "general.xml", + "taxes.xml", +] legislation_xml_info_list = [ - (os.path.join(param_dir, param_file), []) - for param_file in param_files - ] + (os.path.join(param_dir, param_file), []) for param_file in param_files +] xml_to_yaml.write_parameters(legislation_xml_info_list, target_path) diff --git a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py index 91144ed6a0..0b57c19016 100644 --- a/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py +++ b/openfisca_core/scripts/migrations/v16_2_to_v17/xml_to_yaml_extension_template.py @@ -1,31 +1,26 @@ -# -*- coding: utf-8 -*- - -''' xml_to_yaml_extension_template.py : Parse XML parameter files for Extension-Template and convert them to YAML files. Comments are NOT transformed. +"""xml_to_yaml_extension_template.py : Parse XML parameter files for Extension-Template and convert them to YAML files. Comments are NOT transformed. Usage : `python xml_to_yaml_extension_template.py output_dir` or just (output is written in a directory called `yaml_parameters`): `python xml_to_yaml_extension_template.py` -''' +""" -import sys import os +import sys -from . import xml_to_yaml import openfisca_extension_template -if len(sys.argv) > 1: - target_path = sys.argv[1] -else: - target_path = 'yaml_parameters' +from . import xml_to_yaml + +target_path = sys.argv[1] if len(sys.argv) > 1 else "yaml_parameters" param_dir = os.path.dirname(openfisca_extension_template.__file__) param_files = [ - 'parameters.xml', - ] + "parameters.xml", +] legislation_xml_info_list = [ - (os.path.join(param_dir, param_file), []) - for param_file in param_files - ] + (os.path.join(param_dir, param_file), []) for param_file in param_files +] xml_to_yaml.write_parameters(legislation_xml_info_list, target_path) diff --git a/openfisca_core/scripts/migrations/v24_to_25.py b/openfisca_core/scripts/migrations/v24_to_25.py index 853c4e9a94..08bbeddc3b 100644 --- a/openfisca_core/scripts/migrations/v24_to_25.py +++ b/openfisca_core/scripts/migrations/v24_to_25.py @@ -1,37 +1,52 @@ -# -*- coding: utf-8 -*- # flake8: noqa T001 import argparse -import os import glob +import os +from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedSeq -from openfisca_core.scripts import add_tax_benefit_system_arguments, build_tax_benefit_system +from openfisca_core.scripts import ( + add_tax_benefit_system_arguments, + build_tax_benefit_system, +) -from ruamel.yaml import YAML yaml = YAML() yaml.default_flow_style = False yaml.width = 4096 -TEST_METADATA = {'period', 'name', 'reforms', 'only_variables', 'ignore_variables', 'absolute_error_margin', 'relative_error_margin', 'description', 'keywords'} +TEST_METADATA = { + "period", + "name", + "reforms", + "only_variables", + "ignore_variables", + "absolute_error_margin", + "relative_error_margin", + "description", + "keywords", +} def build_parser(): parser = argparse.ArgumentParser() - parser.add_argument('path', help = "paths (files or directories) of tests to execute", nargs = '+') - parser = add_tax_benefit_system_arguments(parser) + parser.add_argument( + "path", + help="paths (files or directories) of tests to execute", + nargs="+", + ) + return add_tax_benefit_system_arguments(parser) - return parser - -class Migrator(object): - - def __init__(self, tax_benefit_system): +class Migrator: + def __init__(self, tax_benefit_system) -> None: self.tax_benefit_system = tax_benefit_system - self.entities_by_plural = {entity.plural: entity for entity in self.tax_benefit_system.entities} + self.entities_by_plural = { + entity.plural: entity for entity in self.tax_benefit_system.entities + } - def migrate(self, path): + def migrate(self, path) -> None: if isinstance(path, list): for item in path: self.migrate(item) @@ -49,8 +64,6 @@ def migrate(self, path): return - print('Migrating {}.'.format(path)) - with open(path) as yaml_file: tests = yaml.safe_load(yaml_file) if isinstance(tests, CommentedSeq): @@ -58,23 +71,23 @@ def migrate(self, path): else: migrated_tests = self.convert_test(tests) - with open(path, 'w') as yaml_file: + with open(path, "w") as yaml_file: yaml.dump(migrated_tests, yaml_file) def convert_test(self, test): - if test.get('output'): + if test.get("output"): # This test is already converted, ignoring it return test result = {} - outputs = test.pop('output_variables') - inputs = test.pop('input_variables', {}) + outputs = test.pop("output_variables") + inputs = test.pop("input_variables", {}) for key, value in test.items(): if key in TEST_METADATA: result[key] = value else: inputs[key] = value - result['input'] = self.convert_inputs(inputs) - result['output'] = outputs + result["input"] = self.convert_inputs(inputs) + result["output"] = outputs return result def convert_inputs(self, inputs): @@ -91,15 +104,15 @@ def convert_inputs(self, inputs): continue results[entity_plural] = self.convert_entities(entity, entities_description) - results = self.generate_missing_entities(results) - - return results + return self.generate_missing_entities(results) def convert_entities(self, entity, entities_description): return { - entity_description.get('id', "{}_{}".format(entity.key, index)): remove_id(entity_description) + entity_description.get("id", f"{entity.key}_{index}"): remove_id( + entity_description, + ) for index, entity_description in enumerate(entities_description) - } + } def generate_missing_entities(self, inputs): for entity in self.tax_benefit_system.entities: @@ -108,29 +121,33 @@ def generate_missing_entities(self, inputs): persons = inputs[self.tax_benefit_system.person_entity.plural] if len(persons) == 1: person_id = next(iter(persons)) - inputs[entity.key] = {entity.roles[0].plural or entity.roles[0].key: [person_id]} + inputs[entity.key] = { + entity.roles[0].plural or entity.roles[0].key: [person_id], + } else: inputs[entity.plural] = { - '{}_{}'.format(entity.key, index): {entity.roles[0].plural or entity.roles[0].key: [person_id]} - for index, person_id in enumerate(persons.keys()) + f"{entity.key}_{index}": { + entity.roles[0].plural or entity.roles[0].key: [person_id], } + for index, person_id in enumerate(persons.keys()) + } return inputs def remove_id(input_dict): - return { - key: value - for (key, value) in input_dict.items() - if key != "id" - } + return {key: value for (key, value) in input_dict.items() if key != "id"} -def main(): +def main() -> None: parser = build_parser() args = parser.parse_args() paths = [os.path.abspath(path) for path in args.path] - tax_benefit_system = build_tax_benefit_system(args.country_package, args.extensions, args.reforms) + tax_benefit_system = build_tax_benefit_system( + args.country_package, + args.extensions, + args.reforms, + ) Migrator(tax_benefit_system).migrate(paths) diff --git a/openfisca_core/scripts/openfisca_command.py b/openfisca_core/scripts/openfisca_command.py index 786b73b35f..d82e0aef61 100644 --- a/openfisca_core/scripts/openfisca_command.py +++ b/openfisca_core/scripts/openfisca_command.py @@ -1,8 +1,9 @@ import argparse -import warnings import sys +import warnings from openfisca_core.scripts import add_tax_benefit_system_arguments + """ Define the `openfisca` command line interface. """ @@ -11,62 +12,157 @@ def get_parser(): parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(help = 'Available commands', dest = 'command') - subparsers.required = True # Can be added as an argument of add_subparsers in Python 3 + subparsers = parser.add_subparsers(help="Available commands", dest="command") + subparsers.required = ( + True # Can be added as an argument of add_subparsers in Python 3 + ) def build_serve_parser(parser): # Define OpenFisca modules configuration parser = add_tax_benefit_system_arguments(parser) # Define server configuration - parser.add_argument('-p', '--port', action = 'store', help = "port to serve on (use --bind to specify host and port)", type = int) - parser.add_argument('--tracker-url', action = 'store', help = "tracking service url", type = str) - parser.add_argument('--tracker-idsite', action = 'store', help = "tracking service id site", type = int) - parser.add_argument('--tracker-token', action = 'store', help = "tracking service authentication token", type = str) - parser.add_argument('--welcome-message', action = 'store', help = "welcome message users will get when visiting the API root", type = str) - parser.add_argument('-f', '--configuration-file', action = 'store', help = "configuration file", type = str) + parser.add_argument( + "-p", + "--port", + action="store", + help="port to serve on (use --bind to specify host and port)", + type=int, + ) + parser.add_argument( + "--tracker-url", + action="store", + help="tracking service url", + type=str, + ) + parser.add_argument( + "--tracker-idsite", + action="store", + help="tracking service id site", + type=int, + ) + parser.add_argument( + "--tracker-token", + action="store", + help="tracking service authentication token", + type=str, + ) + parser.add_argument( + "--welcome-message", + action="store", + help="welcome message users will get when visiting the API root", + type=str, + ) + parser.add_argument( + "-f", + "--configuration-file", + action="store", + help="configuration file", + type=str, + ) return parser - parser_serve = subparsers.add_parser('serve', help = 'Run the OpenFisca Web API') + parser_serve = subparsers.add_parser("serve", help="Run the OpenFisca Web API") parser_serve = build_serve_parser(parser_serve) def build_test_parser(parser): - parser.add_argument('path', help = "paths (files or directories) of tests to execute", nargs = '+') + parser.add_argument( + "path", + help="paths (files or directories) of tests to execute", + nargs="+", + ) parser = add_tax_benefit_system_arguments(parser) - parser.add_argument('-n', '--name_filter', default = None, help = "partial name of tests to execute. Only tests with the given name_filter in their name, file name, or keywords will be run.") - parser.add_argument('-p', '--pdb', action = 'store_true', default = False, help = "drop into debugger on failures or errors") - parser.add_argument('--performance-graph', '--performance', action = 'store_true', default = False, help = "output a performance graph in a 'performance_graph.html' file") - parser.add_argument('--performance-tables', action = 'store_true', default = False, help = "output performance CSV tables") - parser.add_argument('-v', '--verbose', action = 'store_true', default = False, help = "increase output verbosity") - parser.add_argument('-o', '--only-variables', nargs = '*', default = None, help = "variables to test. If specified, only test the given variables.") - parser.add_argument('-i', '--ignore-variables', nargs = '*', default = None, help = "variables to ignore. If specified, do not test the given variables.") + parser.add_argument( + "-n", + "--name_filter", + default=None, + help="partial name of tests to execute. Only tests with the given name_filter in their name, file name, or keywords will be run.", + ) + parser.add_argument( + "-p", + "--pdb", + action="store_true", + default=False, + help="drop into debugger on failures or errors", + ) + parser.add_argument( + "--performance-graph", + "--performance", + action="store_true", + default=False, + help="output a performance graph in a 'performance_graph.html' file", + ) + parser.add_argument( + "--performance-tables", + action="store_true", + default=False, + help="output performance CSV tables", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="increase output verbosity. If specified, output the entire calculation trace.", + ) + parser.add_argument( + "-a", + "--aggregate", + action="store_true", + default=False, + help="increase output verbosity to aggregate. If specified, output the avg, max, and min values of the calculation trace. This flag has no effect without --verbose.", + ) + parser.add_argument( + "-d", + "--max-depth", + type=int, + default=None, + help="set maximal verbosity depth. If specified, output the calculation trace up to the provided depth. This flag has no effect without --verbose.", + ) + parser.add_argument( + "-o", + "--only-variables", + nargs="*", + default=None, + help="variables to test. If specified, only test the given variables.", + ) + parser.add_argument( + "-i", + "--ignore-variables", + nargs="*", + default=None, + help="variables to ignore. If specified, do not test the given variables.", + ) return parser - parser_test = subparsers.add_parser('test', help = 'Run OpenFisca YAML tests') + parser_test = subparsers.add_parser("test", help="Run OpenFisca YAML tests") parser_test = build_test_parser(parser_test) return parser def main(): - if sys.argv[0].endswith('openfisca-run-test'): - sys.argv[0:1] = ['openfisca', 'test'] + if sys.argv[0].endswith("openfisca-run-test"): + sys.argv[0:1] = ["openfisca", "test"] message = "The 'openfisca-run-test' command has been deprecated in favor of 'openfisca test' since version 25.0, and will be removed in the future." - warnings.warn(message, Warning) + warnings.warn(message, Warning, stacklevel=2) parser = get_parser() args, _ = parser.parse_known_args() - if args.command == 'serve': + if args.command == "serve": from openfisca_web_api.scripts.serve import main + return sys.exit(main(parser)) - if args.command == 'test': + if args.command == "test": from openfisca_core.scripts.run_test import main + return sys.exit(main(parser)) + return None -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/openfisca_core/scripts/remove_fuzzy.py b/openfisca_core/scripts/remove_fuzzy.py index 669af2d01b..a4827aef39 100755 --- a/openfisca_core/scripts/remove_fuzzy.py +++ b/openfisca_core/scripts/remove_fuzzy.py @@ -1,28 +1,26 @@ # remove_fuzzy.py : Remove the fuzzy attribute in xml files and add END tags. # See https://github.com/openfisca/openfisca-core/issues/437 -import re import datetime +import re import sys -import numpy as np -assert(len(sys.argv) == 2) +import numpy + +assert len(sys.argv) == 2 filename = sys.argv[1] -with open(filename, 'r') as f: +with open(filename) as f: lines = f.readlines() # Remove fuzzy -lines_2 = [ - line.replace(' fuzzy="true"', '') - for line in lines - ] +lines_2 = [line.replace(' fuzzy="true"', "") for line in lines] -regex_indent = r'^(\s*)\n$' -bool_code = [ - bool(re.search(regex_code, line)) - for line in lines_5 - ] +bool_code = [bool(re.search(regex_code, line)) for line in lines_5] -bool_code_end = [ - bool(re.search(regex_code_end, line)) - for line in lines_5 - ] +bool_code_end = [bool(re.search(regex_code_end, line)) for line in lines_5] list_value = [] for line in lines_5: @@ -227,19 +194,19 @@ to_remove = [] for i in range(len(lines_5) - 1): - if (list_value[i] is not None) and (list_value[i + 1] is not None) and (list_value[i] == list_value[i + 1]): + if ( + (list_value[i] is not None) + and (list_value[i + 1] is not None) + and (list_value[i] == list_value[i + 1]) + ): to_remove.append(i) to_remove_set = set(to_remove) -lines_6 = [ - line - for j, line in enumerate(lines_5) - if j not in to_remove_set - ] +lines_6 = [line for j, line in enumerate(lines_5) if j not in to_remove_set] # Write -with open(filename, 'w') as f: +with open(filename, "w") as f: for line in lines_6: f.write(line) diff --git a/openfisca_core/scripts/run_test.py b/openfisca_core/scripts/run_test.py index 77c4140899..458dc7e50e 100644 --- a/openfisca_core/scripts/run_test.py +++ b/openfisca_core/scripts/run_test.py @@ -1,28 +1,35 @@ -# -*- coding: utf-8 -*- - import logging -import sys import os +import sys -from openfisca_core.tools.test_runner import run_tests from openfisca_core.scripts import build_tax_benefit_system +from openfisca_core.tools.test_runner import run_tests -def main(parser): +def main(parser) -> None: args = parser.parse_args() - logging.basicConfig(level = logging.DEBUG if args.verbose else logging.WARNING, stream = sys.stdout) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.WARNING, + stream=sys.stdout, + ) - tax_benefit_system = build_tax_benefit_system(args.country_package, args.extensions, args.reforms) + tax_benefit_system = build_tax_benefit_system( + args.country_package, + args.extensions, + args.reforms, + ) options = { - 'pdb': args.pdb, - 'performance_graph': args.performance_graph, - 'performance_tables': args.performance_tables, - 'verbose': args.verbose, - 'name_filter': args.name_filter, - 'only_variables': args.only_variables, - 'ignore_variables': args.ignore_variables, - } + "pdb": args.pdb, + "performance_graph": args.performance_graph, + "performance_tables": args.performance_tables, + "verbose": args.verbose, + "aggregate": args.aggregate, + "max_depth": args.max_depth, + "name_filter": args.name_filter, + "only_variables": args.only_variables, + "ignore_variables": args.ignore_variables, + } paths = [os.path.abspath(path) for path in args.path] sys.exit(run_tests(tax_benefit_system, paths, options)) diff --git a/openfisca_core/scripts/simulation_generator.py b/openfisca_core/scripts/simulation_generator.py index 489f42356f..eca2fa30d1 100644 --- a/openfisca_core/scripts/simulation_generator.py +++ b/openfisca_core/scripts/simulation_generator.py @@ -1,30 +1,32 @@ -import numpy as np - import random + +import numpy + from openfisca_core.simulations import Simulation def make_simulation(tax_benefit_system, nb_persons, nb_groups, **kwargs): - """ - Generate a simulation containing nb_persons persons spread in nb_groups groups. + """Generate a simulation containing nb_persons persons spread in nb_groups groups. - Example: + Example: + >>> from openfisca_core.scripts.simulation_generator import make_simulation + >>> from openfisca_france import CountryTaxBenefitSystem + >>> tbs = CountryTaxBenefitSystem() + >>> simulation = make_simulation( + ... tbs, 400, 100 + ... ) # Create a simulation with 400 persons, spread among 100 families + >>> simulation.calculate("revenu_disponible", 2017) - >>> from openfisca_core.scripts.simulation_generator import make_simulation - >>> from openfisca_france import CountryTaxBenefitSystem - >>> tbs = CountryTaxBenefitSystem() - >>> simulation = make_simulation(tbs, 400, 100) # Create a simulation with 400 persons, spread among 100 families - >>> simulation.calculate('revenu_disponible', 2017) """ - simulation = Simulation(tax_benefit_system = tax_benefit_system, **kwargs) - simulation.persons.ids = np.arange(nb_persons) + simulation = Simulation(tax_benefit_system=tax_benefit_system, **kwargs) + simulation.persons.ids = numpy.arange(nb_persons) simulation.persons.count = nb_persons - adults = [0] + sorted(random.sample(range(1, nb_persons), nb_groups - 1)) + adults = [0, *sorted(random.sample(range(1, nb_persons), nb_groups - 1))] - members_entity_id = np.empty(nb_persons, dtype = int) + members_entity_id = numpy.empty(nb_persons, dtype=int) # A legacy role is an index that every person within an entity has. For instance, the 'demandeur' has legacy role 0, the 'conjoint' 1, the first 'child' 2, the second 3, etc. - members_legacy_role = np.empty(nb_persons, dtype = int) + members_legacy_role = numpy.empty(nb_persons, dtype=int) id_group = -1 for id_person in range(nb_persons): @@ -40,27 +42,49 @@ def make_simulation(tax_benefit_system, nb_persons, nb_groups, **kwargs): if not entity.is_person: entity.members_entity_id = members_entity_id entity.count = nb_groups - entity.members_role = np.where(members_legacy_role == 0, entity.flattened_roles[0], entity.flattened_roles[-1]) + entity.members_role = numpy.where( + members_legacy_role == 0, + entity.flattened_roles[0], + entity.flattened_roles[-1], + ) return simulation -def randomly_init_variable(simulation, variable_name, period, max_value, condition = None): - """ - Initialise a variable with random values (from 0 to max_value) for the given period. - If a condition vector is provided, only set the value of persons or groups for which condition is True. +def randomly_init_variable( + simulation, + variable_name: str, + period, + max_value, + condition=None, +) -> None: + """Initialise a variable with random values (from 0 to max_value) for the given period. + If a condition vector is provided, only set the value of persons or groups for which condition is True. - Example: + Example: + >>> from openfisca_core.scripts.simulation_generator import ( + ... make_simulation, + ... randomly_init_variable, + ... ) + >>> from openfisca_france import CountryTaxBenefitSystem + >>> tbs = CountryTaxBenefitSystem() + >>> simulation = make_simulation( + ... tbs, 400, 100 + ... ) # Create a simulation with 400 persons, spread among 100 families + >>> randomly_init_variable( + ... simulation, + ... "salaire_net", + ... 2017, + ... max_value=50000, + ... condition=simulation.persons.has_role(simulation.famille.DEMANDEUR), + ... ) # Randomly set a salaire_net for all persons between 0 and 50000? + >>> simulation.calculate("revenu_disponible", 2017) - >>> from openfisca_core.scripts.simulation_generator import make_simulation, randomly_init_variable - >>> from openfisca_france import CountryTaxBenefitSystem - >>> tbs = CountryTaxBenefitSystem() - >>> simulation = make_simulation(tbs, 400, 100) # Create a simulation with 400 persons, spread among 100 families - >>> randomly_init_variable(simulation, 'salaire_net', 2017, max_value = 50000, condition = simulation.persons.has_role(simulation.famille.DEMANDEUR)) # Randomly set a salaire_net for all persons between 0 and 50000? - >>> simulation.calculate('revenu_disponible', 2017) - """ + """ if condition is None: condition = True variable = simulation.tax_benefit_system.get_variable(variable_name) population = simulation.get_variable_population(variable_name) - value = (np.random.rand(population.count) * max_value * condition).astype(variable.dtype) + value = (numpy.random.rand(population.count) * max_value * condition).astype( + variable.dtype, + ) simulation.set_input(variable_name, period, value) diff --git a/openfisca_core/simulation_builder.py b/openfisca_core/simulation_builder.py index 57c7765ebe..189bba3bcb 100644 --- a/openfisca_core/simulation_builder.py +++ b/openfisca_core/simulation_builder.py @@ -6,11 +6,11 @@ # The following are transitional imports to ensure non-breaking changes. # Could be deprecated in the next major release. -from openfisca_core.simulations import ( # noqa: F401 +from openfisca_core.simulations import ( # noqa: F401 Simulation, SimulationBuilder, calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax, - ) +) diff --git a/openfisca_core/simulations/__init__.py b/openfisca_core/simulations/__init__.py index 5b02dc1a22..9ab10f81a7 100644 --- a/openfisca_core/simulations/__init__.py +++ b/openfisca_core/simulations/__init__.py @@ -21,8 +21,25 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from openfisca_core.errors import CycleError, NaNCreationError, SpiralError # noqa: F401 +from openfisca_core.errors import CycleError, NaNCreationError, SpiralError -from .helpers import calculate_output_add, calculate_output_divide, check_type, transform_to_strict_syntax # noqa: F401 -from .simulation import Simulation # noqa: F401 -from .simulation_builder import SimulationBuilder # noqa: F401 +from .helpers import ( + calculate_output_add, + calculate_output_divide, + check_type, + transform_to_strict_syntax, +) +from .simulation import Simulation +from .simulation_builder import SimulationBuilder + +__all__ = [ + "CycleError", + "NaNCreationError", + "Simulation", + "SimulationBuilder", + "SpiralError", + "calculate_output_add", + "calculate_output_divide", + "check_type", + "transform_to_strict_syntax", +] diff --git a/openfisca_core/simulations/_build_default_simulation.py b/openfisca_core/simulations/_build_default_simulation.py new file mode 100644 index 0000000000..adc7cf4783 --- /dev/null +++ b/openfisca_core/simulations/_build_default_simulation.py @@ -0,0 +1,159 @@ +"""This module contains the _BuildDefaultSimulation class.""" + +from typing import Union +from typing_extensions import Self + +import numpy + +from .simulation import Simulation +from .typing import Entity, Population, TaxBenefitSystem + + +class _BuildDefaultSimulation: + """Build a default simulation. + + Args: + tax_benefit_system(TaxBenefitSystem): The tax-benefit system. + count(int): The number of periods. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 1 + >>> builder = ( + ... _BuildDefaultSimulation(tax_benefit_system, count) + ... .add_count() + ... .add_ids() + ... .add_members_entity_id() + ... ) + + >>> builder.count + 1 + + >>> sorted(builder.populations.keys()) + ['dog', 'pack'] + + >>> sorted(builder.simulation.populations.keys()) + ['dog', 'pack'] + + """ + + #: The number of Population. + count: int + + #: The built populations. + populations: dict[str, Union[Population[Entity]]] + + #: The built simulation. + simulation: Simulation + + def __init__(self, tax_benefit_system: TaxBenefitSystem, count: int) -> None: + self.count = count + self.populations = tax_benefit_system.instantiate_entities() + self.simulation = Simulation(tax_benefit_system, self.populations) + + def add_count(self) -> Self: + """Add the number of Population to the simulation. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_count() + <..._BuildDefaultSimulation object at ...> + + >>> builder.populations["dog"].count + 2 + + >>> builder.populations["pack"].count + 2 + + """ + for population in self.populations.values(): + population.count = self.count + + return self + + def add_ids(self) -> Self: + """Add the populations ids to the simulation. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_ids() + <..._BuildDefaultSimulation object at ...> + + >>> builder.populations["dog"].ids + array([0, 1]) + + >>> builder.populations["pack"].ids + array([0, 1]) + + """ + for population in self.populations.values(): + population.ids = numpy.array(range(self.count)) + + return self + + def add_members_entity_id(self) -> Self: + """Add ??? + + Each SingleEntity has its own GroupEntity. + + Returns: + _BuildDefaultSimulation: The builder. + + Examples: + >>> from openfisca_core import entities, taxbenefitsystems + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> count = 2 + >>> builder = _BuildDefaultSimulation(tax_benefit_system, count) + + >>> builder.add_members_entity_id() + <..._BuildDefaultSimulation object at ...> + + >>> population = builder.populations["pack"] + + >>> hasattr(population, "members_entity_id") + True + + >>> population.members_entity_id + array([0, 1]) + + """ + for population in self.populations.values(): + if hasattr(population, "members_entity_id"): + population.members_entity_id = numpy.array(range(self.count)) + + return self diff --git a/openfisca_core/simulations/_build_from_variables.py b/openfisca_core/simulations/_build_from_variables.py new file mode 100644 index 0000000000..20f49ce113 --- /dev/null +++ b/openfisca_core/simulations/_build_from_variables.py @@ -0,0 +1,230 @@ +"""This module contains the _BuildFromVariables class.""" + +from __future__ import annotations + +from typing_extensions import Self + +from openfisca_core import errors + +from ._build_default_simulation import _BuildDefaultSimulation +from ._type_guards import is_variable_dated +from .simulation import Simulation +from .typing import Entity, Population, TaxBenefitSystem, Variables + + +class _BuildFromVariables: + """Build a simulation from variables. + + Args: + tax_benefit_system(TaxBenefitSystem): The tax-benefit system. + params(Variables): The simulation parameters. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = ( + ... _BuildFromVariables(tax_benefit_system, variables, period) + ... .add_dated_values() + ... .add_undated_values() + ... ) + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + array([10000], dtype=int32) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + array([5000], dtype=int32) + + """ + + #: The number of Population. + count: int + + #: The Simulation's default period. + default_period: str | None + + #: The built populations. + populations: dict[str, Population[Entity]] + + #: The built simulation. + simulation: Simulation + + #: The simulation parameters. + variables: Variables + + def __init__( + self, + tax_benefit_system: TaxBenefitSystem, + params: Variables, + default_period: str | None = None, + ) -> None: + self.count = _person_count(params) + + default_builder = ( + _BuildDefaultSimulation(tax_benefit_system, self.count) + .add_count() + .add_ids() + .add_members_entity_id() + ) + + self.variables = params + self.simulation = default_builder.simulation + self.populations = default_builder.populations + self.default_period = default_period + + def add_dated_values(self) -> Self: + """Add the dated input values to the Simulation. + + Returns: + _BuildFromVariables: The builder. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = _BuildFromVariables(tax_benefit_system, variables) + >>> builder.add_dated_values() + <..._BuildFromVariables object at ...> + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + array([10000], dtype=int32) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + + """ + for variable, value in self.variables.items(): + if is_variable_dated(dated_variable := value): + for period, dated_value in dated_variable.items(): + self.simulation.set_input(variable, period, dated_value) + + return self + + def add_undated_values(self) -> Self: + """Add the undated input values to the Simulation. + + Returns: + _BuildFromVariables: The builder. + + Raises: + SituationParsingError: If there is not a default period set. + + Examples: + >>> from openfisca_core import entities, periods, taxbenefitsystems, variables + + >>> role = {"key": "stray", "plural": "stray", "label": "", "doc": ""} + >>> single_entity = entities.Entity("dog", "dogs", "", "") + >>> group_entity = entities.GroupEntity("pack", "packs", "", "", [role]) + + >>> class salary(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = single_entity + ... value_type = int + + >>> class taxes(variables.Variable): + ... definition_period = periods.DateUnit.MONTH + ... entity = group_entity + ... value_type = int + + >>> test_entities = [single_entity, group_entity] + >>> tax_benefit_system = taxbenefitsystems.TaxBenefitSystem(test_entities) + >>> tax_benefit_system.load_variable(salary) + <...salary object at ...> + >>> tax_benefit_system.load_variable(taxes) + <...taxes object at ...> + >>> period = "2023-12" + >>> variables = {"salary": {period: 10000}, "taxes": 5000} + >>> builder = _BuildFromVariables(tax_benefit_system, variables) + >>> builder.add_undated_values() + Traceback (most recent call last): + openfisca_core.errors.situation_parsing_error.SituationParsingError + >>> builder.default_period = period + >>> builder.add_undated_values() + <..._BuildFromVariables object at ...> + + >>> dogs = builder.populations["dog"].get_holder("salary") + >>> dogs.get_array(period) + + >>> pack = builder.populations["pack"].get_holder("taxes") + >>> pack.get_array(period) + array([5000], dtype=int32) + + """ + for variable, value in self.variables.items(): + if not is_variable_dated(undated_value := value): + if (period := self.default_period) is None: + message = ( + "Can't deal with type: expected object. Input " + "variables should be set for specific periods. For " + "instance: " + " {'salary': {'2017-01': 2000, '2017-02': 2500}}" + " {'birth_date': {'ETERNITY': '1980-01-01'}}" + ) + + raise errors.SituationParsingError([variable], message) + + self.simulation.set_input(variable, period, undated_value) + + return self + + +def _person_count(params: Variables) -> int: + try: + first_value = next(iter(params.values())) + + if isinstance(first_value, dict): + first_value = next(iter(first_value.values())) + + if isinstance(first_value, str): + return 1 + + return len(first_value) + + except Exception: + return 1 diff --git a/openfisca_core/simulations/_type_guards.py b/openfisca_core/simulations/_type_guards.py new file mode 100644 index 0000000000..990248213d --- /dev/null +++ b/openfisca_core/simulations/_type_guards.py @@ -0,0 +1,298 @@ +"""Type guards to help type narrowing simulation parameters.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing_extensions import TypeGuard + +from .typing import ( + Axes, + DatedVariable, + FullySpecifiedEntities, + ImplicitGroupEntities, + Params, + UndatedVariable, + Variables, +) + + +def are_entities_fully_specified( + params: Params, + items: Iterable[str], +) -> TypeGuard[FullySpecifiedEntities]: + """Check if the params contain fully specified entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the params contain fully specified entities. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [ + ... { + ... "count": 2, + ... "max": 3000, + ... "min": 0, + ... "name": "rent", + ... "period": "2018-11", + ... } + ... ] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": {"Alicia": {"salary": {"2018-11": 0}}, "Javier": {}, "Tom": {}}, + ... } + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_fully_specified(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_fully_specified(params, entities) + False + + >>> params = {} + + >>> are_entities_fully_specified(params, entities) + False + + """ + if not params: + return False + + return all(key in items for key in params if key != "axes") + + +def are_entities_short_form( + params: Params, + items: Iterable[str], +) -> TypeGuard[ImplicitGroupEntities]: + """Check if the params contain short form entities. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of entities in singular form. + + Returns: + bool: True if the params contain short form entities. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + False + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_short_form(params, entities) + False + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = { + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"household": {"parents": "Javier"}} + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {"salary": 12000} + + >>> are_entities_short_form(params, entities) + False + + >>> params = {} + + >>> are_entities_short_form(params, entities) + False + + """ + return bool(set(params).intersection(items)) + + +def are_entities_specified( + params: Params, + items: Iterable[str], +) -> TypeGuard[Variables]: + """Check if the params contains entities at all. + + Args: + params(Params): Simulation parameters. + items(Iterable[str]): List of variables. + + Returns: + bool: True if the params does not contain variables at the root level. + + Examples: + >>> variables = {"salary"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_specified(params, variables) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": 2000}}}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"household": {"parents": ["Javier"]}} + + >>> are_entities_specified(params, variables) + True + + >>> params = {"salary": {"2016-10": [12000, 13000]}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": [12000, 13000]} + + >>> are_entities_specified(params, variables) + False + + >>> params = {"salary": 12000} + + >>> are_entities_specified(params, variables) + False + + >>> params = {} + + >>> are_entities_specified(params, variables) + False + + """ + if not params: + return False + + return not any(key in items for key in params) + + +def has_axes(params: Params) -> TypeGuard[Axes]: + """Check if the params contains axes. + + Args: + params(Params): Simulation parameters. + + Returns: + bool: True if the params contain axes. + + Examples: + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> has_axes(params) + True + + >>> params = {"persons": {"Javier": {"salary": {"2018-11": [2000, 3000]}}}} + + >>> has_axes(params) + False + + """ + return params.get("axes", None) is not None + + +def is_variable_dated( + variable: DatedVariable | UndatedVariable, +) -> TypeGuard[DatedVariable]: + """Check if the variable is dated. + + Args: + variable(DatedVariable | UndatedVariable): A variable. + + Returns: + bool: True if the variable is dated. + + Examples: + >>> variable = {"2018-11": [2000, 3000]} + + >>> is_variable_dated(variable) + True + + >>> variable = {"2018-11": 2000} + + >>> is_variable_dated(variable) + True + + >>> variable = 2000 + + >>> is_variable_dated(variable) + False + + """ + return isinstance(variable, dict) diff --git a/openfisca_core/simulations/helpers.py b/openfisca_core/simulations/helpers.py index 683a4106b9..7929c5beda 100644 --- a/openfisca_core/simulations/helpers.py +++ b/openfisca_core/simulations/helpers.py @@ -1,27 +1,106 @@ -from openfisca_core.errors import SituationParsingError +from collections.abc import Iterable +from openfisca_core import errors -def calculate_output_add(simulation, variable_name, period): +from .typing import ParamsWithoutAxes + + +def calculate_output_add(simulation, variable_name: str, period): return simulation.calculate_add(variable_name, period) -def calculate_output_divide(simulation, variable_name, period): +def calculate_output_divide(simulation, variable_name: str, period): return simulation.calculate_divide(variable_name, period) -def check_type(input, input_type, path = None): +def check_type(input, input_type, path=None) -> None: json_type_map = { dict: "Object", list: "Array", str: "String", - } + } if path is None: path = [] if not isinstance(input, input_type): - raise SituationParsingError(path, - "Invalid type: must be of type '{}'.".format(json_type_map[input_type])) + raise errors.SituationParsingError( + path, + f"Invalid type: must be of type '{json_type_map[input_type]}'.", + ) + + +def check_unexpected_entities( + params: ParamsWithoutAxes, + entities: Iterable[str], +) -> None: + """Check if the input contains entities that are not in the system. + + Args: + params(ParamsWithoutAxes): Simulation parameters. + entities(Iterable[str]): List of entities in plural form. + + Raises: + SituationParsingError: If there are entities that are not in the system. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + + >>> check_unexpected_entities(params, entities) + + >>> params = {"dogs": {"Bart": {"damages": {"2018-11": 2000}}}} + + >>> check_unexpected_entities(params, entities) + Traceback (most recent call last): + openfisca_core.errors.situation_parsing_error.SituationParsingError + + """ + if has_unexpected_entities(params, entities): + unexpected_entities = [entity for entity in params if entity not in entities] + + message = ( + "Some entities in the situation are not defined in the loaded tax " + "and benefit system. " + f"These entities are not found: {', '.join(unexpected_entities)}. " + f"The defined entities are: {', '.join(entities)}." + ) + + raise errors.SituationParsingError([unexpected_entities[0]], message) + + +def has_unexpected_entities(params: ParamsWithoutAxes, entities: Iterable[str]) -> bool: + """Check if the input contains entities that are not in the system. + + Args: + params(ParamsWithoutAxes): Simulation parameters. + entities(Iterable[str]): List of entities in plural form. + + Returns: + bool: True if the input contains entities that are not in the system. + + Examples: + >>> entities = {"persons", "households"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + + >>> has_unexpected_entities(params, entities) + False + + >>> params = {"dogs": {"Bart": {"damages": {"2018-11": 2000}}}} + + >>> has_unexpected_entities(params, entities) + True + + """ + return any(entity for entity in params if entity not in entities) def transform_to_strict_syntax(data): @@ -30,16 +109,3 @@ def transform_to_strict_syntax(data): if isinstance(data, list): return [str(item) if isinstance(item, int) else item for item in data] return data - - -def _get_person_count(input_dict): - try: - first_value = next(iter(input_dict.values())) - if isinstance(first_value, dict): - first_value = next(iter(first_value.values())) - if isinstance(first_value, str): - return 1 - - return len(first_value) - except Exception: - return 1 diff --git a/openfisca_core/simulations/simulation.py b/openfisca_core/simulations/simulation.py index 5dd2694292..c32fea22af 100644 --- a/openfisca_core/simulations/simulation.py +++ b/openfisca_core/simulations/simulation.py @@ -1,28 +1,38 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import NamedTuple + +from openfisca_core.types import Population, TaxBenefitSystem, Variable + import tempfile import warnings import numpy -from openfisca_core import commons, periods -from openfisca_core.errors import CycleError, SpiralError -from openfisca_core.indexed_enums import Enum, EnumArray -from openfisca_core.periods import Period -from openfisca_core.tracers import FullTracer, SimpleTracer, TracingParameterNodeAtInstant -from openfisca_core.warnings import TempfileWarning +from openfisca_core import ( + commons, + errors, + indexed_enums, + periods, + tracers, + warnings as core_warnings, +) class Simulation: - """ - Represents a simulation, and handles the calculation logic - """ + """Represents a simulation, and handles the calculation logic.""" + + tax_benefit_system: TaxBenefitSystem + populations: dict[str, Population] + invalidated_caches: set[Cache] def __init__( - self, - tax_benefit_system, - populations - ): - """ - This constructor is reserved for internal use; see :any:`SimulationBuilder`, + self, + tax_benefit_system: TaxBenefitSystem, + populations: Mapping[str, Population], + ) -> None: + """This constructor is reserved for internal use; see :any:`SimulationBuilder`, which is the preferred way to obtain a Simulation initialized with a consistent set of Entities. """ @@ -38,11 +48,11 @@ def __init__( self.debug = False self.trace = False - self.tracer = SimpleTracer() + self.tracer = tracers.SimpleTracer() self.opt_out_cache = False # controls the spirals detection; check for performance impact if > 1 - self.max_spiral_loops = 1 + self.max_spiral_loops: int = 1 self.memory_config = None self._data_storage_dir = None @@ -51,42 +61,45 @@ def trace(self): return self._trace @trace.setter - def trace(self, trace): + def trace(self, trace) -> None: self._trace = trace if trace: - self.tracer = FullTracer() + self.tracer = tracers.FullTracer() else: - self.tracer = SimpleTracer() + self.tracer = tracers.SimpleTracer() - def link_to_entities_instances(self): - for _key, entity_instance in self.populations.items(): + def link_to_entities_instances(self) -> None: + for entity_instance in self.populations.values(): entity_instance.simulation = self - def create_shortcuts(self): - for _key, population in self.populations.items(): + def create_shortcuts(self) -> None: + for population in self.populations.values(): # create shortcut simulation.person and simulation.household (for instance) setattr(self, population.entity.key, population) @property def data_storage_dir(self): - """ - Temporary folder used to store intermediate calculation data in case the memory is saturated - """ + """Temporary folder used to store intermediate calculation data in case the memory is saturated.""" if self._data_storage_dir is None: - self._data_storage_dir = tempfile.mkdtemp(prefix = "openfisca_") + self._data_storage_dir = tempfile.mkdtemp(prefix="openfisca_") message = [ - ("Intermediate results will be stored on disk in {} in case of memory overflow.").format(self._data_storage_dir), - "You should remove this directory once you're done with your simulation." - ] - warnings.warn(" ".join(message), TempfileWarning) + ( + f"Intermediate results will be stored on disk in {self._data_storage_dir} in case of memory overflow." + ), + "You should remove this directory once you're done with your simulation.", + ] + warnings.warn( + " ".join(message), + core_warnings.TempfileWarning, + stacklevel=2, + ) return self._data_storage_dir # ----- Calculation methods ----- # - def calculate(self, variable_name, period): + def calculate(self, variable_name: str, period): """Calculate ``variable_name`` for ``period``.""" - - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) self.tracer.record_calculation_start(variable_name, period) @@ -100,15 +113,22 @@ def calculate(self, variable_name, period): self.tracer.record_calculation_end() self.purge_cache_of_invalid_values() - def _calculate(self, variable_name, period: Period): - """ - Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. + def _calculate(self, variable_name: str, period: periods.Period): + """Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists. :returns: A numpy array containing the result of the calculation """ + variable: Variable | None + population = self.get_variable_population(variable_name) holder = population.get_holder(variable_name) - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) self._check_period_consistency(period, variable) @@ -131,75 +151,163 @@ def _calculate(self, variable_name, period: Period): array = self._cast_formula_result(array, variable) holder.put_in_cache(array, period) - except SpiralError: + except errors.SpiralError: array = holder.default_array() return array - def purge_cache_of_invalid_values(self): + def purge_cache_of_invalid_values(self) -> None: # We wait for the end of calculate(), signalled by an empty stack, before purging the cache if self.tracer.stack: return - for (_name, _period) in self.invalidated_caches: + for _name, _period in self.invalidated_caches: holder = self.get_holder(_name) holder.delete_arrays(_period) self.invalidated_caches = set() - def calculate_add(self, variable_name, period): - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + def calculate_add(self, variable_name: str, period): + variable: Variable | None + + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) # Check that the requested period matches definition_period - if periods.unit_weight(variable.definition_period) > periods.unit_weight(period.unit): - raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' can only be computed for {2}-long periods. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format( - variable.name, - period, - variable.definition_period - )) - - if variable.definition_period not in [periods.DAY, periods.MONTH, periods.YEAR]: - raise ValueError("Unable to sum constant variable '{}' over period {}: only variables defined daily, monthly, or yearly can be summed over time.".format( - variable.name, - period)) + if periods.unit_weight(variable.definition_period) > periods.unit_weight( + period.unit, + ): + msg = ( + f"Unable to compute variable '{variable.name}' for period " + f"{period}: '{variable.name}' can only be computed for " + f"{variable.definition_period}-long periods. You can use the " + f"DIVIDE option to get an estimate of {variable.name}." + ) + raise ValueError( + msg, + ) + + if variable.definition_period not in ( + periods.DateUnit.isoformat + periods.DateUnit.isocalendar + ): + msg = ( + f"Unable to ADD constant variable '{variable.name}' over " + f"the period {period}: eternal variables can't be summed " + "over time." + ) + raise ValueError( + msg, + ) return sum( self.calculate(variable_name, sub_period) for sub_period in period.get_subperiods(variable.definition_period) - ) + ) + + def calculate_divide(self, variable_name: str, period): + variable: Variable | None - def calculate_divide(self, variable_name, period): - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) - # Check that the requested period matches definition_period - if variable.definition_period != periods.YEAR: - raise ValueError("Unable to divide the value of '{}' over time on period {}: only variables defined yearly can be divided over time.".format( - variable_name, - period)) + if ( + periods.unit_weight(variable.definition_period) + < periods.unit_weight(period.unit) + or period.size > 1 + ): + msg = ( + f"Can't calculate variable '{variable.name}' for period " + f"{period}: '{variable.name}' can only be computed for " + f"{variable.definition_period}-long periods. You can use the " + f"ADD option to get an estimate of {variable.name}." + ) + raise ValueError( + msg, + ) - if period.size != 1: - raise ValueError("DIVIDE option can only be used for a one-year or a one-month requested period") + if variable.definition_period not in ( + periods.DateUnit.isoformat + periods.DateUnit.isocalendar + ): + msg = ( + f"Unable to DIVIDE constant variable '{variable.name}' over " + f"the period {period}: eternal variables can't be divided " + "over time." + ) + raise ValueError( + msg, + ) - if period.unit == periods.MONTH: - computation_period = period.this_year - return self.calculate(variable_name, period = computation_period) / 12. - elif period.unit == periods.YEAR: - return self.calculate(variable_name, period) + if ( + period.unit + not in (periods.DateUnit.isoformat + periods.DateUnit.isocalendar) + or period.size != 1 + ): + msg = ( + f"Unable to DIVIDE constant variable '{variable.name}' over " + f"the period {period}: eternal variables can't be used " + "as a denominator to divide a variable over time." + ) + raise ValueError( + msg, + ) - raise ValueError("Unable to divide the value of '{}' to match period {}.".format( - variable_name, - period)) + if variable.definition_period == periods.DateUnit.YEAR: + calculation_period = period.this_year - def calculate_output(self, variable_name, period): - """ - Calculate the value of a variable using the ``calculate_output`` attribute of the variable. - """ + elif variable.definition_period == periods.DateUnit.MONTH: + calculation_period = period.first_month + + elif variable.definition_period == periods.DateUnit.DAY: + calculation_period = period.first_day + + elif variable.definition_period == periods.DateUnit.WEEK: + calculation_period = period.first_week + + else: + calculation_period = period.first_weekday - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + if period.unit == periods.DateUnit.YEAR: + denominator = calculation_period.size_in_years + + elif period.unit == periods.DateUnit.MONTH: + denominator = calculation_period.size_in_months + + elif period.unit == periods.DateUnit.DAY: + denominator = calculation_period.size_in_days + + elif period.unit == periods.DateUnit.WEEK: + denominator = calculation_period.size_in_weeks + + else: + denominator = calculation_period.size_in_weekdays + + return self.calculate(variable_name, calculation_period) / denominator + + def calculate_output(self, variable_name: str, period): + """Calculate the value of a variable using the ``calculate_output`` attribute of the variable.""" + variable: Variable | None + + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) if variable.calculate_output is None: return self.calculate(variable_name, period) @@ -207,16 +315,13 @@ def calculate_output(self, variable_name, period): return variable.calculate_output(self, variable_name, period) def trace_parameters_at_instant(self, formula_period): - return TracingParameterNodeAtInstant( + return tracers.TracingParameterNodeAtInstant( self.tax_benefit_system.get_parameters_at_instant(formula_period), - self.tracer - ) + self.tracer, + ) def _run_formula(self, variable, population, period): - """ - Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``. - """ - + """Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.""" formula = variable.get_formula(period) if formula is None: return None @@ -233,34 +338,49 @@ def _run_formula(self, variable, population, period): return array - def _check_period_consistency(self, period, variable): - """ - Check that a period matches the variable definition_period - """ - if variable.definition_period == periods.ETERNITY: + def _check_period_consistency(self, period, variable) -> None: + """Check that a period matches the variable definition_period.""" + if variable.definition_period == periods.DateUnit.ETERNITY: return # For variables which values are constant in time, all periods are accepted - if variable.definition_period == periods.MONTH and period.unit != periods.MONTH: - raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format( - variable.name, - period - )) + if ( + variable.definition_period == periods.DateUnit.YEAR + and period.unit != periods.DateUnit.YEAR + ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {variable.name} by dividing the yearly value by 12, or change the requested period to 'period.this_year'." + raise ValueError( + msg, + ) - if variable.definition_period == periods.YEAR and period.unit != periods.YEAR: - raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format( - variable.name, - period - )) + if ( + variable.definition_period == periods.DateUnit.MONTH + and period.unit != periods.DateUnit.MONTH + ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole month. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_month'." + raise ValueError( + msg, + ) + + if ( + variable.definition_period == periods.DateUnit.WEEK + and period.unit != periods.DateUnit.WEEK + ): + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole week. You can use the ADD option to sum '{variable.name}' over the requested period, or change the requested period to 'period.first_week'." + raise ValueError( + msg, + ) if period.size != 1: - raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format( - variable.name, - period, - 'month' if variable.definition_period == periods.MONTH else 'year' - )) + msg = f"Unable to compute variable '{variable.name}' for period {period}: '{variable.name}' must be computed for a whole {variable.definition_period}. You can use the ADD option to sum '{variable.name}' over the requested period." + raise ValueError( + msg, + ) def _cast_formula_result(self, value, variable): - if variable.value_type == Enum and not isinstance(value, EnumArray): + if variable.value_type == indexed_enums.Enum and not isinstance( + value, + indexed_enums.EnumArray, + ): return variable.possible_values.encode(value) if not isinstance(value, numpy.ndarray): @@ -274,122 +394,117 @@ def _cast_formula_result(self, value, variable): # ----- Handle circular dependencies in a calculation ----- # - def _check_for_cycle(self, variable: str, period): - """ - Raise an exception in the case of a circular definition, where evaluating a variable for + def _check_for_cycle(self, variable: str, period) -> None: + """Raise an exception in the case of a circular definition, where evaluating a variable for a given period loops around to evaluating the same variable/period pair. Also guards, as a heuristic, against "quasicircles", where the evaluation of a variable at a period involves the same variable at a different period. """ # The last frame is the current calculation, so it should be ignored from cycle detection - previous_periods = [frame['period'] for frame in self.tracer.stack[:-1] if frame['name'] == variable] + previous_periods = [ + frame["period"] + for frame in self.tracer.stack[:-1] + if frame["name"] == variable + ] if period in previous_periods: - raise CycleError("Circular definition detected on formula {}@{}".format(variable, period)) + msg = f"Circular definition detected on formula {variable}@{period}" + raise errors.CycleError( + msg, + ) spiral = len(previous_periods) >= self.max_spiral_loops if spiral: self.invalidate_spiral_variables(variable) - message = "Quasicircular definition detected on formula {}@{} involving {}".format(variable, period, self.tracer.stack) - raise SpiralError(message, variable) + message = f"Quasicircular definition detected on formula {variable}@{period} involving {self.tracer.stack}" + raise errors.SpiralError(message, variable) - def invalidate_cache_entry(self, variable: str, period): - self.invalidated_caches.add((variable, period)) + def invalidate_cache_entry(self, variable: str, period) -> None: + self.invalidated_caches.add(Cache(variable, period)) - def invalidate_spiral_variables(self, variable: str): + def invalidate_spiral_variables(self, variable: str) -> None: # Visit the stack, from the bottom (most recent) up; we know that we'll find # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the # intermediate values computed (to avoid impacting performance) but we mark them # for deletion from the cache once the calculation ends. count = 0 for frame in reversed(self.tracer.stack): - self.invalidate_cache_entry(frame['name'], frame['period']) - if frame['name'] == variable: + self.invalidate_cache_entry(str(frame["name"]), frame["period"]) + if frame["name"] == variable: count += 1 if count > self.max_spiral_loops: break # ----- Methods to access stored values ----- # - def get_array(self, variable_name, period): - """ - Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). + def get_array(self, variable_name: str, period): + """Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated). Unlike :meth:`.calculate`, this method *does not* trigger calculations and *does not* use any formula. """ - if period is not None and not isinstance(period, Period): + if period is not None and not isinstance(period, periods.Period): period = periods.period(period) return self.get_holder(variable_name).get_array(period) - def get_holder(self, variable_name): - """ - Get the :obj:`.Holder` associated with the variable ``variable_name`` for the simulation - """ + def get_holder(self, variable_name: str): + """Get the holder associated with the variable.""" return self.get_variable_population(variable_name).get_holder(variable_name) - def get_memory_usage(self, variables = None): - """ - Get data about the virtual memory usage of the simulation - """ - result = dict( - total_nb_bytes = 0, - by_variable = {} - ) + def get_memory_usage(self, variables=None): + """Get data about the virtual memory usage of the simulation.""" + result = {"total_nb_bytes": 0, "by_variable": {}} for entity in self.populations.values(): - entity_memory_usage = entity.get_memory_usage(variables = variables) - result['total_nb_bytes'] += entity_memory_usage['total_nb_bytes'] - result['by_variable'].update(entity_memory_usage['by_variable']) + entity_memory_usage = entity.get_memory_usage(variables=variables) + result["total_nb_bytes"] += entity_memory_usage["total_nb_bytes"] + result["by_variable"].update(entity_memory_usage["by_variable"]) return result # ----- Misc ----- # - def delete_arrays(self, variable, period = None): - """ - Delete a variable's value for a given period + def delete_arrays(self, variable, period=None) -> None: + """Delete a variable's value for a given period. :param variable: the variable to be set :param period: the period for which the value should be deleted Example: - >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.get_array('age', '2018-05') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.get_array("age", "2018-05") array([13, 14], dtype=int32) - >>> simulation.delete_arrays('age', '2018-05') - >>> simulation.get_array('age', '2018-04') + >>> simulation.delete_arrays("age", "2018-05") + >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) - >>> simulation.get_array('age', '2018-05') is None + >>> simulation.get_array("age", "2018-05") is None True - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.delete_arrays('age') - >>> simulation.get_array('age', '2018-04') is None + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.delete_arrays("age") + >>> simulation.get_array("age", "2018-04") is None True - >>> simulation.get_array('age', '2018-05') is None + >>> simulation.get_array("age", "2018-05") is None True + """ self.get_holder(variable).delete_arrays(period) def get_known_periods(self, variable): - """ - Get a list variable's known period, i.e. the periods where a value has been initialized and + """Get a list variable's known period, i.e. the periods where a value has been initialized and. :param variable: the variable to be set Example: - >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.set_input('age', '2018-05', [13, 14]) - >>> simulation.get_known_periods('age') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.set_input("age", "2018-05", [13, 14]) + >>> simulation.get_known_periods("age") [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))] + """ return self.get_holder(variable).get_known_periods() - def set_input(self, variable_name, period, value): - """ - Set a variable's value for a given period + def set_input(self, variable_name: str, period, value) -> None: + """Set a variable's value for a given period. :param variable: the variable to be set :param value: the input value for the variable @@ -398,41 +513,71 @@ def set_input(self, variable_name, period, value): Example: >>> from openfisca_country_template import CountryTaxBenefitSystem >>> simulation = Simulation(CountryTaxBenefitSystem()) - >>> simulation.set_input('age', '2018-04', [12, 14]) - >>> simulation.get_array('age', '2018-04') + >>> simulation.set_input("age", "2018-04", [12, 14]) + >>> simulation.get_array("age", "2018-04") array([12, 14], dtype=int32) If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation `_. + """ - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) + variable: Variable | None + + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) + period = periods.period(period) - if ((variable.end is not None) and (period.start.date > variable.end)): + if (variable.end is not None) and (period.start.date > variable.end): return self.get_holder(variable_name).set_input(period, value) - def get_variable_population(self, variable_name): - variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True) - return self.populations[variable.entity.key] + def get_variable_population(self, variable_name: str) -> Population: + variable: Variable | None - def get_population(self, plural = None): - return next((population for population in self.populations.values() if population.entity.plural == plural), None) + variable = self.tax_benefit_system.get_variable( + variable_name, + check_existence=True, + ) + + if variable is None: + raise errors.VariableNotFoundError(variable_name, self.tax_benefit_system) - def get_entity(self, plural = None): + return self.populations[variable.entity.key] + + def get_population(self, plural: str | None = None) -> Population | None: + return next( + ( + population + for population in self.populations.values() + if population.entity.plural == plural + ), + None, + ) + + def get_entity( + self, + plural: str | None = None, + ) -> Population | None: population = self.get_population(plural) return population and population.entity def describe_entities(self): - return {population.entity.plural: population.ids for population in self.populations.values()} + return { + population.entity.plural: population.ids + for population in self.populations.values() + } - def clone(self, debug = False, trace = False): - """ - Copy the simulation just enough to be able to run the copy without modifying the original simulation - """ + def clone(self, debug=False, trace=False): + """Copy the simulation just enough to be able to run the copy without modifying the original simulation.""" new = commons.empty_clone(self) new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ('debug', 'trace', 'tracer'): + if key not in ("debug", "trace", "tracer"): new_dict[key] = value new.persons = self.persons.clone(new) @@ -442,9 +587,18 @@ def clone(self, debug = False, trace = False): for entity in self.tax_benefit_system.group_entities: population = self.populations[entity.key].clone(new) new.populations[entity.key] = population - setattr(new, entity.key, population) # create shortcut simulation.household (for instance) + setattr( + new, + entity.key, + population, + ) # create shortcut simulation.household (for instance) new.debug = debug new.trace = trace return new + + +class Cache(NamedTuple): + variable: str + period: periods.Period diff --git a/openfisca_core/simulations/simulation_builder.py b/openfisca_core/simulations/simulation_builder.py index e331accf02..064b5b4cb6 100644 --- a/openfisca_core/simulations/simulation_builder.py +++ b/openfisca_core/simulations/simulation_builder.py @@ -1,227 +1,427 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array +from typing import NoReturn + import copy -import dpath -import typing +import dpath.util import numpy -from openfisca_core import periods -from openfisca_core.entities import Entity -from openfisca_core.errors import PeriodMismatchError, SituationParsingError, VariableNotFoundError -from openfisca_core.populations import Population -from openfisca_core.simulations import helpers, Simulation -from openfisca_core.variables import Variable +from openfisca_core import entities, errors, periods, populations, variables + +from . import helpers +from ._build_default_simulation import _BuildDefaultSimulation +from ._build_from_variables import _BuildFromVariables +from ._type_guards import ( + are_entities_fully_specified, + are_entities_short_form, + are_entities_specified, + has_axes, +) +from .simulation import Simulation +from .typing import ( + Axis, + Entity, + FullySpecifiedEntities, + GroupEntities, + GroupEntity, + ImplicitGroupEntities, + Params, + ParamsWithoutAxes, + Population, + Role, + SingleEntity, + TaxBenefitSystem, + Variables, +) class SimulationBuilder: - - def __init__(self): - self.default_period = None # Simulation period used for variables when no period is defined - self.persons_plural = None # Plural name for person entity in current tax and benefits system + def __init__(self) -> None: + self.default_period = ( + None # Simulation period used for variables when no period is defined + ) + self.persons_plural = ( + None # Plural name for person entity in current tax and benefits system + ) # JSON input - Memory of known input values. Indexed by variable or axis name. - self.input_buffer: typing.Dict[Variable.name, typing.Dict[str(periods.period), numpy.array]] = {} - self.populations: typing.Dict[Entity.key, Population] = {} + self.input_buffer: dict[ + variables.Variable.name, + dict[str(periods.period), numpy.array], + ] = {} + self.populations: dict[entities.Entity.key, populations.Population] = {} # JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes. - self.entity_counts: typing.Dict[Entity.plural, int] = {} - # JSON input - typing.List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. - self.entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} + self.entity_counts: dict[entities.Entity.plural, int] = {} + # JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``. + self.entity_ids: dict[entities.Entity.plural, list[int]] = {} # Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id) - self.memberships: typing.Dict[Entity.plural, typing.List[int]] = {} - self.roles: typing.Dict[Entity.plural, typing.List[int]] = {} + self.memberships: dict[entities.Entity.plural, list[int]] = {} + self.roles: dict[entities.Entity.plural, list[int]] = {} - self.variable_entities: typing.Dict[Variable.name, Entity] = {} + self.variable_entities: dict[variables.Variable.name, entities.Entity] = {} self.axes = [[]] - self.axes_entity_counts: typing.Dict[Entity.plural, int] = {} - self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {} + self.axes_entity_counts: dict[entities.Entity.plural, int] = {} + self.axes_entity_ids: dict[entities.Entity.plural, list[int]] = {} + self.axes_memberships: dict[entities.Entity.plural, list[int]] = {} + self.axes_roles: dict[entities.Entity.plural, list[int]] = {} + + def build_from_dict( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: Params, + ) -> Simulation: + """Build a simulation from an input dictionary. + + This method uses :meth:`.SimulationBuilder.build_from_entities` if + entities are fully specified, or + :meth:`.SimulationBuilder.build_from_variables` if they are not. + + Args: + tax_benefit_system: The system to use. + input_dict: The input of the simulation. + + Returns: + Simulation: The built simulation. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> entities = {"persons", "households"} + + >>> params = { + ... "axes": [ + ... [ + ... { + ... "count": 2, + ... "max": 3000, + ... "min": 0, + ... "name": "rent", + ... "period": "2018-11", + ... } + ... ] + ... ], + ... "households": { + ... "housea": {"parents": ["Alicia", "Javier"]}, + ... "houseb": {"parents": ["Tom"]}, + ... }, + ... "persons": { + ... "Alicia": {"salary": {"2018-11": 0}}, + ... "Javier": {}, + ... "Tom": {}, + ... }, + ... } + + >>> are_entities_short_form(params, entities) + True + + >>> params = {"salary": [12000, 13000]} + + >>> not are_entities_specified(params, {"salary"}) + True - def build_from_dict(self, tax_benefit_system, input_dict): """ - Build a simulation from ``input_dict`` + #: The plural names of the entities in the tax and benefits system. + plural: Iterable[str] = tax_benefit_system.entities_plural() - This method uses :any:`build_from_entities` if entities are fully specified, or :any:`build_from_variables` if not. + #: The singular names of the entities in the tax and benefits system. + singular: Iterable[str] = tax_benefit_system.entities_by_singular() - :param dict input_dict: A dict represeting the input of the simulation - :return: A :any:`Simulation` - """ + #: The names of the variables in the tax and benefits system. + variables: Iterable[str] = tax_benefit_system.variables.keys() - input_dict = self.explicit_singular_entities(tax_benefit_system, input_dict) - if any(key in tax_benefit_system.entities_plural() for key in input_dict.keys()): - return self.build_from_entities(tax_benefit_system, input_dict) - else: - return self.build_from_variables(tax_benefit_system, input_dict) + if are_entities_short_form(input_dict, singular): + params = self.explicit_singular_entities(tax_benefit_system, input_dict) + return self.build_from_entities(tax_benefit_system, params) - def build_from_entities(self, tax_benefit_system, input_dict): - """ - Build a simulation from a Python dict ``input_dict`` fully specifying entities. + if are_entities_fully_specified(params := input_dict, plural): + return self.build_from_entities(tax_benefit_system, params) - Examples: + if not are_entities_specified(params := input_dict, variables): + return self.build_from_variables(tax_benefit_system, params) + return None + + def build_from_entities( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: FullySpecifiedEntities, + ) -> Simulation: + """Build a simulation from a Python dict ``input_dict`` fully specifying + entities. + + Examples: + >>> entities = {"person", "household"} + + >>> params = { + ... "persons": {"Javier": {"salary": {"2018-11": 2000}}}, + ... "household": {"parents": ["Javier"]}, + ... "axes": [[{"count": 1, "max": 1, "min": 1, "name": "household"}]], + ... } + + >>> are_entities_short_form(params, entities) + True - >>> simulation_builder.build_from_entities({ - 'persons': {'Javier': { 'salary': {'2018-11': 2000}}}, - 'households': {'household': {'parents': ['Javier']}} - }) """ + # Create the populations + populations = tax_benefit_system.instantiate_entities() + + # Create the simulation + simulation = Simulation(tax_benefit_system, populations) + + # Why? input_dict = copy.deepcopy(input_dict) - simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) + # The plural names of the entities in the tax and benefits system. + plural: Iterable[str] = tax_benefit_system.entities_plural() # Register variables so get_variable_entity can find them - for (variable_name, _variable) in tax_benefit_system.variables.items(): - self.register_variable(variable_name, simulation.get_variable_population(variable_name).entity) - - helpers.check_type(input_dict, dict, ['error']) - axes = input_dict.pop('axes', None) - - unexpected_entities = [entity for entity in input_dict if entity not in tax_benefit_system.entities_plural()] - if unexpected_entities: - unexpected_entity = unexpected_entities[0] - raise SituationParsingError([unexpected_entity], - ''.join([ - "Some entities in the situation are not defined in the loaded tax and benefit system.", - "These entities are not found: {0}.", - "The defined entities are: {1}."] - ) - .format( - ', '.join(unexpected_entities), - ', '.join(tax_benefit_system.entities_plural()) - ) - ) - persons_json = input_dict.get(tax_benefit_system.person_entity.plural, None) + self.register_variables(simulation) + + # Declare axes + axes: list[list[Axis]] | None = None + + # ? + helpers.check_type(input_dict, dict, ["error"]) + + # Remove axes from input_dict + params: ParamsWithoutAxes = { + key: value for key, value in input_dict.items() if key != "axes" + } + + # Save axes for later + if has_axes(axes_params := input_dict): + axes = copy.deepcopy(axes_params.get("axes", None)) + + # Check for unexpected entities + helpers.check_unexpected_entities(params, plural) + + person_entity: SingleEntity = tax_benefit_system.person_entity + + persons_json = params.get(person_entity.plural, None) if not persons_json: - raise SituationParsingError([tax_benefit_system.person_entity.plural], - 'No {0} found. At least one {0} must be defined to run a simulation.'.format(tax_benefit_system.person_entity.key)) + raise errors.SituationParsingError( + [person_entity.plural], + f"No {person_entity.key} found. At least one {person_entity.key} must be defined to run a simulation.", + ) persons_ids = self.add_person_entity(simulation.persons.entity, persons_json) for entity_class in tax_benefit_system.group_entities: - instances_json = input_dict.get(entity_class.plural) + instances_json = params.get(entity_class.plural) + if instances_json is not None: - self.add_group_entity(self.persons_plural, persons_ids, entity_class, instances_json) + self.add_group_entity( + self.persons_plural, + persons_ids, + entity_class, + instances_json, + ) + + elif axes is not None: + message = ( + f"We could not find any specified {entity_class.plural}. " + "In order to expand over axes, all group entities and roles " + "must be fully specified. For further support, please do " + "not hesitate to take a look at the official documentation: " + "https://openfisca.org/doc/simulate/replicate-simulation-inputs.html." + ) + + raise errors.SituationParsingError([entity_class.plural], message) + else: self.add_default_group_entity(persons_ids, entity_class) - if axes: - self.axes = axes + if axes is not None: + for axis in axes[0]: + self.add_parallel_axis(axis) + + if len(axes) >= 1: + for axis in axes[1:]: + self.add_perpendicular_axis(axis[0]) + self.expand_axes() try: self.finalize_variables_init(simulation.persons) - except PeriodMismatchError as e: + except errors.PeriodMismatchError as e: self.raise_period_mismatch(simulation.persons.entity, persons_json, e) for entity_class in tax_benefit_system.group_entities: try: population = simulation.populations[entity_class.key] self.finalize_variables_init(population) - except PeriodMismatchError as e: + except errors.PeriodMismatchError as e: self.raise_period_mismatch(population.entity, instances_json, e) return simulation - def build_from_variables(self, tax_benefit_system, input_dict): - """ - Build a simulation from a Python dict ``input_dict`` describing variables values without expliciting entities. + def build_from_variables( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: Variables, + ) -> Simulation: + """Build a simulation from a Python dict ``input_dict`` describing + variables values without expliciting entities. - This method uses :any:`build_default_simulation` to infer an entity structure + This method uses :meth:`.SimulationBuilder.build_default_simulation` to + infer an entity structure. - Example: + Args: + tax_benefit_system: The system to use. + input_dict: The input of the simulation. - >>> simulation_builder.build_from_variables( - {'salary': {'2016-10': 12000}} - ) - """ - count = helpers._get_person_count(input_dict) - simulation = self.build_default_simulation(tax_benefit_system, count) - for variable, value in input_dict.items(): - if not isinstance(value, dict): - if self.default_period is None: - raise SituationParsingError([variable], - "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.") - simulation.set_input(variable, self.default_period, value) - else: - for period_str, dated_value in value.items(): - simulation.set_input(variable, period_str, dated_value) - return simulation + Returns: + Simulation: The built simulation. - def build_default_simulation(self, tax_benefit_system, count = 1): - """ - Build a simulation where: - - There are ``count`` persons - - There are ``count`` instances of each group entity, containing one person - - Every person has, in each entity, the first role - """ + Raises: + SituationParsingError: If the input is not valid. - simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) - for population in simulation.populations.values(): - population.count = count - population.ids = numpy.array(range(count)) - if not population.entity.is_person: - population.members_entity_id = population.ids # Each person is its own group entity - return simulation + Examples: + >>> params = {"salary": {"2016-10": 12000}} + + >>> are_entities_specified(params, {"salary"}) + False + + >>> params = {"salary": 12000} - def create_entities(self, tax_benefit_system): + >>> are_entities_specified(params, {"salary"}) + False + + """ + return ( + _BuildFromVariables(tax_benefit_system, input_dict, self.default_period) + .add_dated_values() + .add_undated_values() + .simulation + ) + + @staticmethod + def build_default_simulation( + tax_benefit_system: TaxBenefitSystem, + count: int = 1, + ) -> Simulation: + """Build a default simulation. + + Where: + - There are ``count`` persons + - There are ``count`` of each group entity, containing one person + - Every person has, in each entity, the first role + + """ + return ( + _BuildDefaultSimulation(tax_benefit_system, count) + .add_count() + .add_ids() + .add_members_entity_id() + .simulation + ) + + def create_entities(self, tax_benefit_system) -> None: self.populations = tax_benefit_system.instantiate_entities() - def declare_person_entity(self, person_singular, persons_ids: typing.Iterable): + def declare_person_entity(self, person_singular, persons_ids: Iterable) -> None: person_instance = self.populations[person_singular] person_instance.ids = numpy.array(list(persons_ids)) person_instance.count = len(person_instance.ids) self.persons_plural = person_instance.entity.plural - def declare_entity(self, entity_singular, entity_ids: typing.Iterable): + def declare_entity(self, entity_singular, entity_ids: Iterable): entity_instance = self.populations[entity_singular] entity_instance.ids = numpy.array(list(entity_ids)) entity_instance.count = len(entity_instance.ids) return entity_instance - def nb_persons(self, entity_singular, role = None): - return self.populations[entity_singular].nb_persons(role = role) + def nb_persons(self, entity_singular, role=None): + return self.populations[entity_singular].nb_persons(role=role) - def join_with_persons(self, group_population, persons_group_assignment, roles: typing.Iterable[str]): + def join_with_persons( + self, + group_population, + persons_group_assignment, + roles: Iterable[str], + ) -> None: # Maps group's identifiers to a 0-based integer range, for indexing into members_roles (see PR#876) - group_sorted_indices = numpy.unique(persons_group_assignment, return_inverse = True)[1] - group_population.members_entity_id = numpy.argsort(group_population.ids)[group_sorted_indices] + group_sorted_indices = numpy.unique( + persons_group_assignment, + return_inverse=True, + )[1] + group_population.members_entity_id = numpy.argsort(group_population.ids)[ + group_sorted_indices + ] flattened_roles = group_population.entity.flattened_roles roles_array = numpy.array(roles) if numpy.issubdtype(roles_array.dtype, numpy.integer): group_population.members_role = numpy.array(flattened_roles)[roles_array] + elif len(flattened_roles) == 0: + group_population.members_role = numpy.int16(0) else: - if len(flattened_roles) == 0: - group_population.members_role = numpy.int64(0) - else: - group_population.members_role = numpy.select([roles_array == role.key for role in flattened_roles], flattened_roles) + group_population.members_role = numpy.select( + [roles_array == role.key for role in flattened_roles], + flattened_roles, + ) def build(self, tax_benefit_system): return Simulation(tax_benefit_system, self.populations) - def explicit_singular_entities(self, tax_benefit_system, input_dict): - """ - Preprocess ``input_dict`` to explicit entities defined using the single-entity shortcut + def explicit_singular_entities( + self, + tax_benefit_system: TaxBenefitSystem, + input_dict: ImplicitGroupEntities, + ) -> GroupEntities: + """Preprocess ``input_dict`` to explicit entities defined using the + single-entity shortcut. - Example: + Examples: + >>> params = { + ... "persons": { + ... "Javier": {}, + ... }, + ... "household": {"parents": ["Javier"]}, + ... } - >>> simulation_builder.explicit_singular_entities( - {'persons': {'Javier': {}, }, 'household': {'parents': ['Javier']}} - ) - >>> {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}} - """ + >>> are_entities_fully_specified(params, {"persons", "households"}) + False + + >>> are_entities_short_form(params, {"person", "household"}) + True + + >>> params = { + ... "persons": {"Javier": {}}, + ... "households": {"household": {"parents": ["Javier"]}}, + ... } + + >>> are_entities_fully_specified(params, {"persons", "households"}) + True - singular_keys = set(input_dict).intersection(tax_benefit_system.entities_by_singular()) - if not singular_keys: - return input_dict + >>> are_entities_short_form(params, {"person", "household"}) + False + + """ + singular_keys = set(input_dict).intersection( + tax_benefit_system.entities_by_singular(), + ) result = { entity_id: entity_description for (entity_id, entity_description) in input_dict.items() if entity_id in tax_benefit_system.entities_plural() - } # filter out the singular entities + } # filter out the singular entities for singular in singular_keys: plural = tax_benefit_system.entities_by_singular()[singular].plural @@ -230,9 +430,7 @@ def explicit_singular_entities(self, tax_benefit_system, input_dict): return result def add_person_entity(self, entity, instances_json): - """ - Add the simulation's instances of the persons entity as described in ``instances_json``. - """ + """Add the simulation's instances of the persons entity as described in ``instances_json``.""" helpers.check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) self.persons_plural = entity.plural @@ -245,17 +443,28 @@ def add_person_entity(self, entity, instances_json): return self.get_ids(entity.plural) - def add_default_group_entity(self, persons_ids, entity): + def add_default_group_entity( + self, + persons_ids: list[str], + entity: GroupEntity, + ) -> None: persons_count = len(persons_ids) + roles = list(entity.flattened_roles) self.entity_ids[entity.plural] = persons_ids self.entity_counts[entity.plural] = persons_count - self.memberships[entity.plural] = numpy.arange(0, persons_count, dtype = numpy.int32) - self.roles[entity.plural] = numpy.repeat(entity.flattened_roles[0], persons_count) - - def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): - """ - Add all instances of one of the model's entities as described in ``instances_json``. - """ + self.memberships[entity.plural] = list( + numpy.arange(0, persons_count, dtype=numpy.int32), + ) + self.roles[entity.plural] = [roles[0]] * persons_count + + def add_group_entity( + self, + persons_plural: str, + persons_ids: list[str], + entity: GroupEntity, + instances_json, + ) -> None: + """Add all instances of one of the model's entities as described in ``instances_json``.""" helpers.check_type(instances_json, dict, [entity.plural]) entity_ids = list(map(str, instances_json.keys())) @@ -264,8 +473,8 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): persons_count = len(persons_ids) persons_to_allocate = set(persons_ids) - self.memberships[entity.plural] = numpy.empty(persons_count, dtype = numpy.int32) - self.roles[entity.plural] = numpy.empty(persons_count, dtype = object) + self.memberships[entity.plural] = numpy.empty(persons_count, dtype=numpy.int32) + self.roles[entity.plural] = numpy.empty(persons_count, dtype=object) self.entity_ids[entity.plural] = entity_ids self.entity_counts[entity.plural] = len(entity_ids) @@ -276,18 +485,31 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): variables_json = instance_object.copy() # Don't mutate function input roles_json = { - role.plural or role.key: helpers.transform_to_strict_syntax(variables_json.pop(role.plural or role.key, [])) + role.plural + or role.key: helpers.transform_to_strict_syntax( + variables_json.pop(role.plural or role.key, []), + ) for role in entity.roles - } + } for role_id, role_definition in roles_json.items(): - helpers.check_type(role_definition, list, [entity.plural, instance_id, role_id]) + helpers.check_type( + role_definition, + list, + [entity.plural, instance_id, role_id], + ) for index, person_id in enumerate(role_definition): entity_plural = entity.plural - self.check_persons_to_allocate(persons_plural, entity_plural, - persons_ids, - person_id, instance_id, role_id, - persons_to_allocate, index) + self.check_persons_to_allocate( + persons_plural, + entity_plural, + persons_ids, + person_id, + instance_id, + role_id, + persons_to_allocate, + index, + ) persons_to_allocate.discard(person_id) @@ -298,12 +520,17 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): role = role_by_plural[role_plural] if role.max is not None and len(persons_with_role) > role.max: - raise SituationParsingError([entity.plural, instance_id, role_plural], f"There can be at most {role.max} {role_plural} in a {entity.key}. {len(persons_with_role)} were declared in '{instance_id}'.") + raise errors.SituationParsingError( + [entity.plural, instance_id, role_plural], + f"There can be at most {role.max} {role_plural} in a {entity.key}. {len(persons_with_role)} were declared in '{instance_id}'.", + ) for index_within_role, person_id in enumerate(persons_with_role): person_index = persons_ids.index(person_id) self.memberships[entity.plural][person_index] = entity_index - person_role = role.subroles[index_within_role] if role.subroles else role + person_role = ( + role.subroles[index_within_role] if role.subroles else role + ) self.roles[entity.plural][person_index] = person_role self.init_variable_values(entity, variables_json, instance_id) @@ -312,7 +539,9 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): entity_ids = entity_ids + list(persons_to_allocate) for person_id in persons_to_allocate: person_index = persons_ids.index(person_id) - self.memberships[entity.plural][person_index] = entity_ids.index(person_id) + self.memberships[entity.plural][person_index] = entity_ids.index( + person_id, + ) self.roles[entity.plural][person_index] = entity.flattened_roles[0] # Adjust previously computed ids and counts self.entity_ids[entity.plural] = entity_ids @@ -322,61 +551,87 @@ def add_group_entity(self, persons_plural, persons_ids, entity, instances_json): self.roles[entity.plural] = self.roles[entity.plural].tolist() self.memberships[entity.plural] = self.memberships[entity.plural].tolist() - def set_default_period(self, period_str): + def set_default_period(self, period_str) -> None: if period_str: self.default_period = str(periods.period(period_str)) - def get_input(self, variable, period_str): + def get_input(self, variable: str, period_str: str) -> Array | None: if variable not in self.input_buffer: self.input_buffer[variable] = {} + return self.input_buffer[variable].get(period_str) - def check_persons_to_allocate(self, persons_plural, entity_plural, - persons_ids, - person_id, entity_id, role_id, - persons_to_allocate, index): - helpers.check_type(person_id, str, [entity_plural, entity_id, role_id, str(index)]) + def check_persons_to_allocate( + self, + persons_plural, + entity_plural, + persons_ids, + person_id, + entity_id, + role_id, + persons_to_allocate, + index, + ) -> None: + helpers.check_type( + person_id, + str, + [entity_plural, entity_id, role_id, str(index)], + ) if person_id not in persons_ids: - raise SituationParsingError([entity_plural, entity_id, role_id], - "Unexpected value: {0}. {0} has been declared in {1} {2}, but has not been declared in {3}.".format( - person_id, entity_id, role_id, persons_plural) - ) + raise errors.SituationParsingError( + [entity_plural, entity_id, role_id], + f"Unexpected value: {person_id}. {person_id} has been declared in {entity_id} {role_id}, but has not been declared in {persons_plural}.", + ) if person_id not in persons_to_allocate: - raise SituationParsingError([entity_plural, entity_id, role_id], - "{} has been declared more than once in {}".format( - person_id, entity_plural) - ) - - def init_variable_values(self, entity, instance_object, instance_id): - query = entity.variables.isdefined() + raise errors.SituationParsingError( + [entity_plural, entity_id, role_id], + f"{person_id} has been declared more than once in {entity_plural}", + ) + def init_variable_values(self, entity, instance_object, instance_id) -> None: for variable_name, variable_values in instance_object.items(): path_in_json = [entity.plural, instance_id, variable_name] - try: - variable = query.get(variable_name) + entity.check_variable_defined_for_entity(variable_name) except ValueError as e: # The variable is defined for another entity - raise SituationParsingError(path_in_json, e.args[0]) - except VariableNotFoundError as e: # The variable doesn't exist - raise SituationParsingError(path_in_json, str(e), code = 404) + raise errors.SituationParsingError(path_in_json, e.args[0]) + except errors.VariableNotFoundError as e: # The variable doesn't exist + raise errors.SituationParsingError(path_in_json, str(e), code=404) instance_index = self.get_ids(entity.plural).index(instance_id) if not isinstance(variable_values, dict): if self.default_period is None: - raise SituationParsingError(path_in_json, - "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.") + raise errors.SituationParsingError( + path_in_json, + "Can't deal with type: expected object. Input variables should be set for specific periods. For instance: {'salary': {'2017-01': 2000, '2017-02': 2500}}, or {'birth_date': {'ETERNITY': '1980-01-01'}}.", + ) variable_values = {self.default_period: variable_values} for period_str, value in variable_values.items(): try: periods.period(period_str) except ValueError as e: - raise SituationParsingError(path_in_json, e.args[0]) - - self.add_variable_value(entity, variable, instance_index, instance_id, period_str, value) + raise errors.SituationParsingError(path_in_json, e.args[0]) + variable = entity.get_variable(variable_name) + self.add_variable_value( + entity, + variable, + instance_index, + instance_id, + period_str, + value, + ) - def add_variable_value(self, entity, variable, instance_index, instance_id, period_str, value): + def add_variable_value( + self, + entity, + variable, + instance_index, + instance_id, + period_str, + value, + ) -> None: path_in_json = [entity.plural, instance_id, variable.name, period_str] if value is None: @@ -391,13 +646,13 @@ def add_variable_value(self, entity, variable, instance_index, instance_id, peri try: value = variable.check_set_value(value) except ValueError as error: - raise SituationParsingError(path_in_json, *error.args) + raise errors.SituationParsingError(path_in_json, *error.args) array[instance_index] = value self.input_buffer[variable.name][str(periods.period(period_str))] = array - def finalize_variables_init(self, population): + def finalize_variables_init(self, population) -> None: # Due to set_input mechanism, we must bufferize all inputs, then actually set them, # so that the months are set first and the years last. plural_key = population.entity.plural @@ -407,15 +662,18 @@ def finalize_variables_init(self, population): if plural_key in self.memberships: population.members_entity_id = numpy.array(self.get_memberships(plural_key)) population.members_role = numpy.array(self.get_roles(plural_key)) - for variable_name in self.input_buffer.keys(): + for variable_name in self.input_buffer: try: holder = population.get_holder(variable_name) except ValueError: # Wrong entity, we can just ignore that continue buffer = self.input_buffer[variable_name] - unsorted_periods = [periods.period(period_str) for period_str in self.input_buffer[variable_name].keys()] + unsorted_periods = [ + periods.period(period_str) + for period_str in self.input_buffer[variable_name] + ] # We need to handle small periods first for set_input to work - sorted_periods = sorted(unsorted_periods, key = periods.key_period_size) + sorted_periods = sorted(unsorted_periods, key=periods.key_period_size) for period_value in sorted_periods: values = buffer[str(period_value)] # Hack to replicate the values in the persons entity @@ -427,66 +685,85 @@ def finalize_variables_init(self, population): if (variable.end is None) or (period_value.start.date <= variable.end): holder.set_input(period_value, array) - def raise_period_mismatch(self, entity, json, e): + def raise_period_mismatch(self, entity, json, e) -> NoReturn: # This error happens when we try to set a variable value for a period that doesn't match its definition period # It is only raised when we consume the buffer. We thus don't know which exact key caused the error. # We do a basic research to find the culprit path culprit_path = next( - dpath.search(json, "*/{}/{}".format(e.variable_name, str(e.period)), yielded = True), - None) + dpath.util.search( + json, + f"*/{e.variable_name}/{e.period!s}", + yielded=True, + ), + None, + ) if culprit_path: - path = [entity.plural] + culprit_path[0].split('/') + path = [entity.plural, *culprit_path[0].split("/")] else: - path = [entity.plural] # Fallback: if we can't find the culprit, just set the error at the entities level + path = [ + entity.plural, + ] # Fallback: if we can't find the culprit, just set the error at the entities level - raise SituationParsingError(path, e.message) + raise errors.SituationParsingError(path, e.message) # Returns the total number of instances of this entity, including when there is replication along axes - def get_count(self, entity_name): + def get_count(self, entity_name: str) -> int: return self.axes_entity_counts.get(entity_name, self.entity_counts[entity_name]) # Returns the ids of instances of this entity, including when there is replication along axes - def get_ids(self, entity_name): + def get_ids(self, entity_name: str) -> list[str]: return self.axes_entity_ids.get(entity_name, self.entity_ids[entity_name]) # Returns the memberships of individuals in this entity, including when there is replication along axes def get_memberships(self, entity_name): # Return empty array for the "persons" entity - return self.axes_memberships.get(entity_name, self.memberships.get(entity_name, [])) + return self.axes_memberships.get( + entity_name, + self.memberships.get(entity_name, []), + ) # Returns the roles of individuals in this entity, including when there is replication along axes - def get_roles(self, entity_name): + def get_roles(self, entity_name: str) -> Sequence[Role]: # Return empty array for the "persons" entity return self.axes_roles.get(entity_name, self.roles.get(entity_name, [])) - def add_parallel_axis(self, axis): + def add_parallel_axis(self, axis: Axis) -> None: # All parallel axes have the same count and entity. # Search for a compatible axis, if none exists, error out self.axes[0].append(axis) - def add_perpendicular_axis(self, axis): + def add_perpendicular_axis(self, axis: Axis) -> None: # This adds an axis perpendicular to all previous dimensions self.axes.append([axis]) - def expand_axes(self): + def expand_axes(self) -> None: # This method should be idempotent & allow change in axes - perpendicular_dimensions = self.axes + perpendicular_dimensions: list[list[Axis]] = self.axes + cell_count: int = 1 - cell_count = 1 for parallel_axes in perpendicular_dimensions: - first_axis = parallel_axes[0] - axis_count = first_axis['count'] + first_axis: Axis = parallel_axes[0] + axis_count: int = first_axis["count"] cell_count *= axis_count # Scale the "prototype" situation, repeating it cell_count times - for entity_name in self.entity_counts.keys(): + for entity_name in self.entity_counts: # Adjust counts - self.axes_entity_counts[entity_name] = self.get_count(entity_name) * cell_count + self.axes_entity_counts[entity_name] = ( + self.get_count(entity_name) * cell_count + ) # Adjust ids - original_ids = self.get_ids(entity_name) * cell_count - indices = numpy.arange(0, cell_count * self.entity_counts[entity_name]) - adjusted_ids = [id + str(ix) for id, ix in zip(original_ids, indices)] + original_ids: list[str] = self.get_ids(entity_name) * cell_count + indices: Array[numpy.int16] = numpy.arange( + 0, + cell_count * self.entity_counts[entity_name], + ) + adjusted_ids: list[str] = [ + original_id + str(index) + for original_id, index in zip(original_ids, indices) + ] self.axes_entity_ids[entity_name] = adjusted_ids + # Adjust roles original_roles = self.get_roles(entity_name) adjusted_roles = original_roles * cell_count @@ -495,8 +772,13 @@ def expand_axes(self): if entity_name != self.persons_plural: original_memberships = self.get_memberships(entity_name) repeated_memberships = original_memberships * cell_count - indices = numpy.repeat(numpy.arange(0, cell_count), len(original_memberships)) * self.entity_counts[entity_name] - adjusted_memberships = (numpy.array(repeated_memberships) + indices).tolist() + indices = ( + numpy.repeat(numpy.arange(0, cell_count), len(original_memberships)) + * self.entity_counts[entity_name] + ) + adjusted_memberships = ( + numpy.array(repeated_memberships) + indices + ).tolist() self.axes_memberships[entity_name] = adjusted_memberships # Now generate input values along the specified axes @@ -504,61 +786,72 @@ def expand_axes(self): if len(self.axes) == 1 and len(self.axes[0]): parallel_axes = self.axes[0] first_axis = parallel_axes[0] - axis_count: int = first_axis['count'] - axis_entity = self.get_variable_entity(first_axis['name']) + axis_count: int = first_axis["count"] + axis_entity = self.get_variable_entity(first_axis["name"]) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along axes for axis in parallel_axes: - axis_index = axis.get('index', 0) - axis_period = axis.get('period', self.default_period) - axis_name = axis['name'] + axis_index = axis.get("index", 0) + axis_period = axis.get("period", self.default_period) + axis_name = axis["name"] variable = axis_entity.get_variable(axis_name) array = self.get_input(axis_name, str(axis_period)) if array is None: array = variable.default_array(axis_count * axis_entity_step_size) elif array.size == axis_entity_step_size: array = numpy.tile(array, axis_count) - array[axis_index:: axis_entity_step_size] = numpy.linspace( - axis['min'], - axis['max'], - num = axis_count, - ) + array[axis_index::axis_entity_step_size] = numpy.linspace( + axis["min"], + axis["max"], + num=axis_count, + ) # Set input self.input_buffer[axis_name][str(axis_period)] = array else: - first_axes_count: typing.List[int] = ( - parallel_axes[0]["count"] - for parallel_axes - in self.axes - ) + first_axes_count: list[int] = ( + parallel_axes[0]["count"] for parallel_axes in self.axes + ) axes_linspaces = [ - numpy.linspace(0, axis_count - 1, num = axis_count) - for axis_count - in first_axes_count - ] + numpy.linspace(0, axis_count - 1, num=axis_count) + for axis_count in first_axes_count + ] axes_meshes = numpy.meshgrid(*axes_linspaces) for parallel_axes, mesh in zip(self.axes, axes_meshes): first_axis = parallel_axes[0] - axis_count = first_axis['count'] - axis_entity = self.get_variable_entity(first_axis['name']) + axis_count = first_axis["count"] + axis_entity = self.get_variable_entity(first_axis["name"]) axis_entity_step_size = self.entity_counts[axis_entity.plural] # Distribute values along the grid for axis in parallel_axes: - axis_index = axis.get('index', 0) - axis_period = axis['period'] or self.default_period - axis_name = axis['name'] - variable = axis_entity.get_variable(axis_name) + axis_index = axis.get("index", 0) + axis_period = axis.get("period", self.default_period) + axis_name = axis["name"] + variable = axis_entity.get_variable(axis_name, check_existence=True) array = self.get_input(axis_name, str(axis_period)) if array is None: - array = variable.default_array(cell_count * axis_entity_step_size) + array = variable.default_array( + cell_count * axis_entity_step_size, + ) elif array.size == axis_entity_step_size: array = numpy.tile(array, cell_count) - array[axis_index:: axis_entity_step_size] = axis['min'] \ - + mesh.reshape(cell_count) * (axis['max'] - axis['min']) / (axis_count - 1) + array[axis_index::axis_entity_step_size] = axis[ + "min" + ] + mesh.reshape(cell_count) * (axis["max"] - axis["min"]) / ( + axis_count - 1 + ) self.input_buffer[axis_name][str(axis_period)] = array - def get_variable_entity(self, variable_name): + def get_variable_entity(self, variable_name: str) -> Entity: return self.variable_entities[variable_name] - def register_variable(self, variable_name, entity): + def register_variable(self, variable_name: str, entity: Entity) -> None: self.variable_entities[variable_name] = entity + + def register_variables(self, simulation: Simulation) -> None: + tax_benefit_system: TaxBenefitSystem = simulation.tax_benefit_system + variables: Iterable[str] = tax_benefit_system.variables.keys() + + for name in variables: + population: Population = simulation.get_variable_population(name) + entity: Entity = population.entity + self.register_variable(name, entity) diff --git a/openfisca_core/simulations/typing.py b/openfisca_core/simulations/typing.py new file mode 100644 index 0000000000..8091994e53 --- /dev/null +++ b/openfisca_core/simulations/typing.py @@ -0,0 +1,203 @@ +"""Type aliases of OpenFisca models to use in the context of simulations.""" + +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from numpy.typing import NDArray as Array +from typing import Protocol, TypeVar, TypedDict, Union +from typing_extensions import NotRequired, Required, TypeAlias + +import datetime +from abc import abstractmethod + +from numpy import ( + bool_ as Bool, + datetime64 as Date, + float32 as Float, + int16 as Enum, + int32 as Int, + str_ as String, +) + +#: Generic type variables. +E = TypeVar("E") +G = TypeVar("G", covariant=True) +T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True) +U = TypeVar("U", bool, datetime.date, float, str) +V = TypeVar("V", covariant=True) + + +#: Type alias for a simulation dictionary defining the roles. +Roles: TypeAlias = dict[str, Union[str, Iterable[str]]] + +#: Type alias for a simulation dictionary with undated variables. +UndatedVariable: TypeAlias = dict[str, object] + +#: Type alias for a simulation dictionary with dated variables. +DatedVariable: TypeAlias = dict[str, UndatedVariable] + +#: Type alias for a simulation dictionary with abbreviated entities. +Variables: TypeAlias = dict[str, Union[UndatedVariable, DatedVariable]] + +#: Type alias for a simulation with fully specified single entities. +SingleEntities: TypeAlias = dict[str, dict[str, Variables]] + +#: Type alias for a simulation dictionary with implicit group entities. +ImplicitGroupEntities: TypeAlias = dict[str, Union[Roles, Variables]] + +#: Type alias for a simulation dictionary with explicit group entities. +GroupEntities: TypeAlias = dict[str, ImplicitGroupEntities] + +#: Type alias for a simulation dictionary with fully specified entities. +FullySpecifiedEntities: TypeAlias = Union[SingleEntities, GroupEntities] + +#: Type alias for a simulation dictionary with axes parameters. +Axes: TypeAlias = dict[str, Iterable[Iterable["Axis"]]] + +#: Type alias for a simulation dictionary without axes parameters. +ParamsWithoutAxes: TypeAlias = Union[ + Variables, + ImplicitGroupEntities, + FullySpecifiedEntities, +] + +#: Type alias for a simulation dictionary with axes parameters. +ParamsWithAxes: TypeAlias = Union[Axes, ParamsWithoutAxes] + +#: Type alias for a simulation dictionary with all the possible scenarios. +Params: TypeAlias = ParamsWithAxes + + +class Axis(TypedDict, total=False): + """Interface representing an axis of a simulation.""" + + count: Required[int] + index: NotRequired[int] + max: Required[float] + min: Required[float] + name: Required[str] + period: NotRequired[str | int] + + +class Entity(Protocol): + """Interface representing an entity of a simulation.""" + + key: str + plural: str | None + + def get_variable( + self, + __variable_name: str, + __check_existence: bool = ..., + ) -> Variable[T] | None: + """Get a variable.""" + + +class SingleEntity(Entity, Protocol): + """Interface representing a single entity of a simulation.""" + + +class GroupEntity(Entity, Protocol): + """Interface representing a group entity of a simulation.""" + + @property + @abstractmethod + def flattened_roles(self) -> Iterable[Role[G]]: + """Get the flattened roles of the GroupEntity.""" + + +class Holder(Protocol[V]): + """Interface representing a holder of a simulation's computed values.""" + + @property + @abstractmethod + def variable(self) -> Variable[T]: + """Get the Variable of the Holder.""" + + def get_array(self, __period: str) -> Array[T] | None: + """Get the values of the Variable for a given Period.""" + + def set_input( + self, + __period: Period, + __array: Array[T] | Sequence[U], + ) -> Array[T] | None: + """Set values for a Variable for a given Period.""" + + +class Period(Protocol): + """Interface representing a period of a simulation.""" + + +class Population(Protocol[E]): + """Interface representing a data vector of an Entity.""" + + count: int + entity: E + ids: Array[String] + + def get_holder(self, __variable_name: str) -> Holder[V]: + """Get the holder of a Variable.""" + + +class SinglePopulation(Population[E], Protocol): + """Interface representing a data vector of a SingleEntity.""" + + +class GroupPopulation(Population[E], Protocol): + """Interface representing a data vector of a GroupEntity.""" + + members_entity_id: Array[String] + + def nb_persons(self, __role: Role[G] | None = ...) -> int: + """Get the number of persons for a given Role.""" + + +class Role(Protocol[G]): + """Interface representing a role of the group entities of a simulation.""" + + +class TaxBenefitSystem(Protocol): + """Interface representing a tax-benefit system.""" + + @property + @abstractmethod + def person_entity(self) -> SingleEntity: + """Get the person entity of the tax-benefit system.""" + + @person_entity.setter + @abstractmethod + def person_entity(self, person_entity: SingleEntity) -> None: + """Set the person entity of the tax-benefit system.""" + + @property + @abstractmethod + def variables(self) -> dict[str, V]: + """Get the variables of the tax-benefit system.""" + + def entities_by_singular(self) -> dict[str, E]: + """Get the singular form of the entities' keys.""" + + def entities_plural(self) -> Iterable[str]: + """Get the plural form of the entities' keys.""" + + def get_variable( + self, + __variable_name: str, + __check_existence: bool = ..., + ) -> V | None: + """Get a variable.""" + + def instantiate_entities( + self, + ) -> dict[str, Population[E]]: + """Instantiate the populations of each Entity.""" + + +class Variable(Protocol[T]): + """Interface representing a variable of a tax-benefit system.""" + + end: str + + def default_array(self, __array_size: int) -> Array[T]: + """Fill an array with the default value of the Variable.""" diff --git a/openfisca_core/taxbenefitsystems/tax_benefit_system.py b/openfisca_core/taxbenefitsystems/tax_benefit_system.py index 4fb2a15d22..8c48f64715 100644 --- a/openfisca_core/taxbenefitsystems/tax_benefit_system.py +++ b/openfisca_core/taxbenefitsystems/tax_benefit_system.py @@ -1,20 +1,30 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +from openfisca_core.types import ParameterNodeAtInstant + +import ast import copy +import functools import glob import importlib +import importlib.metadata +import importlib.util import inspect +import linecache import logging import os -import pkg_resources +import sys import traceback -import typing -from imp import find_module, load_module from openfisca_core import commons, periods, variables from openfisca_core.entities import Entity from openfisca_core.errors import VariableNameConflictError, VariableNotFoundError from openfisca_core.parameters import ParameterNode from openfisca_core.periods import Instant, Period -from openfisca_core.populations import Population, GroupPopulation +from openfisca_core.populations import GroupPopulation, Population from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable @@ -22,41 +32,47 @@ class TaxBenefitSystem: - """ - Represents the legislation. + """Represents the legislation. - It stores parameters (values defined for everyone) and variables (values defined for some given entity e.g. a person). + It stores parameters (values defined for everyone) and variables (values + defined for some given entity e.g. a person). - :param entities: Entities used by the tax benefit system. - :param string parameters: Directory containing the YAML parameter files. + Attributes: + parameters: Directory containing the YAML parameter files. + Args: + entities: Entities used by the tax benefit system. - .. attribute:: parameters - - :obj:`.ParameterNode` containing the legislation parameters """ + + person_entity: Entity + _base_tax_benefit_system = None - _parameters_at_instant_cache = None + _parameters_at_instant_cache: dict[Instant, ParameterNodeAtInstant] = {} person_key_plural = None preprocess_parameters = None baseline = None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained. cache_blacklist = None decomposition_file_path = None - def __init__(self, entities): + def __init__(self, entities: Sequence[Entity]) -> None: # TODO: Currently: Don't use a weakref, because they are cleared by Paste (at least) at each call. - self.parameters = None - self._parameters_at_instant_cache = {} # weakref.WeakValueDictionary() - self.variables = {} - self.open_api_config = {} + self.parameters: ParameterNode | None = None + self.variables: dict[Any, Any] = {} + self.open_api_config: dict[Any, Any] = {} # Tax benefit systems are mutable, so entities (which need to know about our variables) can't be shared among them if entities is None or len(entities) == 0: - raise Exception("A tax and benefit sytem must have at least an entity.") + msg = "A tax and benefit sytem must have at least an entity." + raise Exception(msg) self.entities = [copy.copy(entity) for entity in entities] - self.person_entity = [entity for entity in self.entities if entity.is_person][0] - self.group_entities = [entity for entity in self.entities if not entity.is_person] + self.person_entity = next( + entity for entity in self.entities if entity.is_person + ) + self.group_entities = [ + entity for entity in self.entities if not entity.is_person + ] for entity in self.entities: - entity.tax_benefit_system = self + entity.set_tax_benefit_system(self) @property def base_tax_benefit_system(self): @@ -65,13 +81,15 @@ def base_tax_benefit_system(self): baseline = self.baseline if baseline is None: return self - self._base_tax_benefit_system = base_tax_benefit_system = baseline.base_tax_benefit_system + self._base_tax_benefit_system = base_tax_benefit_system = ( + baseline.base_tax_benefit_system + ) return base_tax_benefit_system def instantiate_entities(self): person = self.person_entity members = Population(person) - entities: typing.Dict[Entity.key, Entity] = {person.key: members} + entities: dict[Entity.key, Entity] = {person.key: members} for entity in self.group_entities: entities[entity.key] = GroupPopulation(entity, members) @@ -80,8 +98,8 @@ def instantiate_entities(self): # Deprecated method of constructing simulations, to be phased out in favor of SimulationBuilder def new_scenario(self): - class ScenarioAdapter(object): - def __init__(self, tax_benefit_system): + class ScenarioAdapter: + def __init__(self, tax_benefit_system) -> None: self.tax_benefit_system = tax_benefit_system def init_from_attributes(self, **attributes): @@ -91,10 +109,16 @@ def init_from_attributes(self, **attributes): def init_from_dict(self, dict): self.attributes = None self.dict = dict - self.period = dict.pop('period') + self.period = dict.pop("period") return self - def new_simulation(self, debug = False, opt_out_cache = False, use_baseline = False, trace = False): + def new_simulation( + self, + debug=False, + opt_out_cache=False, + use_baseline=False, + trace=False, + ): # Legacy from scenarios, used in reforms tax_benefit_system = self.tax_benefit_system if use_baseline: @@ -106,13 +130,19 @@ def new_simulation(self, debug = False, opt_out_cache = False, use_baseline = Fa builder = SimulationBuilder() if self.attributes: - variables = self.attributes.get('input_variables') or {} - period = self.attributes.get('period') + variables = self.attributes.get("input_variables") or {} + period = self.attributes.get("period") builder.set_default_period(period) - simulation = builder.build_from_variables(tax_benefit_system, variables) + simulation = builder.build_from_variables( + tax_benefit_system, + variables, + ) else: builder.set_default_period(self.period) - simulation = builder.build_from_entities(tax_benefit_system, self.dict) + simulation = builder.build_from_entities( + tax_benefit_system, + self.dict, + ) simulation.trace = trace simulation.debug = debug @@ -122,93 +152,134 @@ def new_simulation(self, debug = False, opt_out_cache = False, use_baseline = Fa return ScenarioAdapter(self) - def prefill_cache(self): + def prefill_cache(self) -> None: pass - def load_variable(self, variable_class, update = False): + def load_variable(self, variable_class, update=False): name = variable_class.__name__ # Check if a Variable with the same name is already registered. baseline_variable = self.get_variable(name) if baseline_variable and not update: + msg = f'Variable "{name}" is already defined. Use `update_variable` to replace it.' raise VariableNameConflictError( - 'Variable "{}" is already defined. Use `update_variable` to replace it.'.format(name)) + msg, + ) - variable = variable_class(baseline_variable = baseline_variable) + variable = variable_class(baseline_variable=baseline_variable) self.variables[variable.name] = variable return variable - def add_variable(self, variable): - """ - Adds an OpenFisca variable to the tax and benefit system. + def add_variable(self, variable: Variable) -> Variable: + """Adds an OpenFisca variable to the tax and benefit system. - :param .Variable variable: The variable to add. Must be a subclass of Variable. + Args: + variable: The variable to add. Must be a subclass of Variable. - :raises: :exc:`.VariableNameConflictError` if a variable with the same name have previously been added to the tax and benefit system. - """ - return self.load_variable(variable, update = False) + Raises: + openfisca_core.errors.VariableNameConflictError: if a variable with the same name have previously been added to the tax and benefit system. - def replace_variable(self, variable): """ - Replaces an existing OpenFisca variable in the tax and benefit system by a new one. + return self.load_variable(variable, update=False) + + def replace_variable(self, variable: Variable) -> None: + """Replaces an existing variable by a new one. The new variable must have the same name than the replaced one. - If no variable with the given name exists in the tax and benefit system, no error will be raised and the new variable will be simply added. + If no variable with the given name exists in the Tax-Benefit system, no + error will be raised and the new variable will be simply added. + + Args: + variable: The variable to replace. - :param Variable variable: New variable to add. Must be a subclass of Variable. """ name = variable.__name__ + if self.variables.get(name) is not None: del self.variables[name] - self.load_variable(variable, update = False) - def update_variable(self, variable): - """ - Updates an existing OpenFisca variable in the tax and benefit system. + self.load_variable(variable, update=False) + + def update_variable(self, variable: Variable) -> Variable: + """Update an existing variable in the Tax-Benefit system. - All attributes of the updated variable that are not explicitely overridden by the new ``variable`` will stay unchanged. + All attributes of the updated variable that are not explicitly + overridden by the new ``variable`` will stay unchanged. The new variable must have the same name than the updated one. - If no variable with the given name exists in the tax and benefit system, no error will be raised and the new variable will be simply added. + If no variable with the given name exists in the tax and benefit + system, no error will be raised and the new variable will be simply + added. - :param Variable variable: Variable to add. Must be a subclass of Variable. - """ - return self.load_variable(variable, update = True) + Args: + variable: Variable to add. Must be a subclass of Variable. + + Returns: + The added variable. - def add_variables_from_file(self, file_path): - """ - Adds all OpenFisca variables contained in a given file to the tax and benefit system. """ + return self.load_variable(variable, update=True) + + def add_variables_from_file(self, file_path) -> None: + """Adds all OpenFisca variables contained in a given file to the tax and benefit system.""" try: + source_file_path = file_path.replace( + self.get_package_metadata()["location"], + "", + ) + file_name = os.path.splitext(os.path.basename(file_path))[0] # As Python remembers loaded modules by name, in order to prevent collisions, we need to make sure that: # - Files with the same name, but located in different directories, have a different module names. Hence the file path hash in the module name. # - The same file, loaded by different tax and benefit systems, has distinct module names. Hence the `id(self)` in the module name. - module_name = '{}_{}_{}'.format(id(self), hash(os.path.abspath(file_path)), file_name) + module_name = f"{id(self)}_{hash(os.path.abspath(file_path))}_{file_name}" - module_directory = os.path.dirname(file_path) try: - module = load_module(module_name, *find_module(file_name, [module_directory])) + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + lines = linecache.getlines(file_path, module.__dict__) + source = "".join(lines) + tree = ast.parse(source) + defs = {i.name: i for i in tree.body if isinstance(i, ast.ClassDef)} + spec.loader.exec_module(module) + except NameError as e: - logging.error(str(e) + ": if this code used to work, this error might be due to a major change in OpenFisca-Core. Checkout the changelog to learn more: ") + logging.exception( + str(e) + + ": if this code used to work, this error might be due to a major change in OpenFisca-Core. Checkout the changelog to learn more: ", + ) raise - potential_variables = [getattr(module, item) for item in dir(module) if not item.startswith('__')] + potential_variables = [ + getattr(module, item) + for item in dir(module) + if not item.startswith("__") + ] for pot_variable in potential_variables: # We only want to get the module classes defined in this module (not imported) - if inspect.isclass(pot_variable) and issubclass(pot_variable, Variable) and pot_variable.__module__ == module_name: + if ( + inspect.isclass(pot_variable) + and issubclass(pot_variable, Variable) + and pot_variable.__module__ == module_name + ): + class_def = defs[pot_variable.__name__] + pot_variable.introspection_data = ( + source_file_path, + "".join(lines[class_def.lineno - 1 : class_def.end_lineno]), + class_def.lineno - 1, + ) self.add_variable(pot_variable) except Exception: - log.error('Unable to load OpenFisca variables from file "{}"'.format(file_path)) + log.exception(f'Unable to load OpenFisca variables from file "{file_path}"') raise - def add_variables_from_directory(self, directory): - """ - Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system. - """ + def add_variables_from_directory(self, directory) -> None: + """Recursively explores a directory, and adds all OpenFisca variables found there to the tax and benefit system.""" py_files = glob.glob(os.path.join(directory, "*.py")) for py_file in py_files: self.add_variables_from_file(py_file) @@ -216,20 +287,18 @@ def add_variables_from_directory(self, directory): for subdirectory in subdirectories: self.add_variables_from_directory(subdirectory) - def add_variables(self, *variables): - """ - Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. + def add_variables(self, *variables) -> None: + """Adds a list of OpenFisca Variables to the `TaxBenefitSystem`. See also :any:`add_variable` """ for variable in variables: self.add_variable(variable) - def load_extension(self, extension): - """ - Loads an extension to the tax and benefit system. + def load_extension(self, extension) -> None: + """Loads an extension to the tax and benefit system. - :param string extension: The extension to load. Can be an absolute path pointing to an extension directory, or the name of an OpenFisca extension installed as a pip package. + :param str extension: The extension to load. Can be an absolute path pointing to an extension directory, or the name of an OpenFisca extension installed as a pip package. """ # Load extension from installed pip package @@ -237,91 +306,130 @@ def load_extension(self, extension): package = importlib.import_module(extension) extension_directory = package.__path__[0] except ImportError: - message = os.linesep.join([traceback.format_exc(), - 'Error loading extension: `{}` is neither a directory, nor a package.'.format(extension), - 'Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.', - 'See more at .']) + message = os.linesep.join( + [ + traceback.format_exc(), + f"Error loading extension: `{extension}` is neither a directory, nor a package.", + "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", + "See more at .", + ], + ) raise ValueError(message) self.add_variables_from_directory(extension_directory) - param_dir = os.path.join(extension_directory, 'parameters') + param_dir = os.path.join(extension_directory, "parameters") if os.path.isdir(param_dir): - extension_parameters = ParameterNode(directory_path = param_dir) + extension_parameters = ParameterNode(directory_path=param_dir) self.parameters.merge(extension_parameters) - def apply_reform(self, reform_path): - """ - Generates a new tax and benefit system applying a reform to the tax and benefit system. + def apply_reform(self, reform_path: str) -> TaxBenefitSystem: + """Generates a new tax and benefit system applying a reform to the tax and benefit system. The current tax and benefit system is **not** mutated. - :param string reform_path: The reform to apply. Must respect the format *installed_package.sub_module.reform* + Args: + reform_path: The reform to apply. Must respect the format *installed_package.sub_module.reform* - :returns: A reformed tax and benefit system. + Returns: + TaxBenefitSystem: A reformed tax and benefit system. Example: - - >>> self.apply_reform('openfisca_france.reforms.inversion_revenus') + >>> self.apply_reform("openfisca_france.reforms.inversion_revenus") """ from openfisca_core.reforms import Reform + try: - reform_package, reform_name = reform_path.rsplit('.', 1) + reform_package, reform_name = reform_path.rsplit(".", 1) except ValueError: - raise ValueError('`{}` does not seem to be a path pointing to a reform. A path looks like `some_country_package.reforms.some_reform.`'.format(reform_path)) + msg = f"`{reform_path}` does not seem to be a path pointing to a reform. A path looks like `some_country_package.reforms.some_reform.`" + raise ValueError( + msg, + ) try: reform_module = importlib.import_module(reform_package) except ImportError: - message = os.linesep.join([traceback.format_exc(), - 'Could not import `{}`.'.format(reform_package), - 'Are you sure of this reform module name? If so, look at the stack trace above to determine the origin of this error.']) + message = os.linesep.join( + [ + traceback.format_exc(), + f"Could not import `{reform_package}`.", + "Are you sure of this reform module name? If so, look at the stack trace above to determine the origin of this error.", + ], + ) raise ValueError(message) reform = getattr(reform_module, reform_name, None) if reform is None: - raise ValueError('{} has no attribute {}'.format(reform_package, reform_name)) + msg = f"{reform_package} has no attribute {reform_name}" + raise ValueError(msg) if not issubclass(reform, Reform): - raise ValueError('`{}` does not seem to be a valid Openfisca reform.'.format(reform_path)) + msg = f"`{reform_path}` does not seem to be a valid Openfisca reform." + raise ValueError( + msg, + ) return reform(self) - def get_variable(self, variable_name, check_existence = False): - """ - Get a variable from the tax and benefit system. + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> Variable | None: + """Get a variable from the tax and benefit system. :param variable_name: Name of the requested variable. :param check_existence: If True, raise an error if the requested variable does not exist. """ - variables = self.variables - found = variables.get(variable_name) - if not found and check_existence: - raise VariableNotFoundError(variable_name, self) - return found + variables: dict[str, Variable | None] = self.variables + variable: Variable | None = variables.get(variable_name) - def neutralize_variable(self, variable_name): - """ - Neutralizes an OpenFisca variable existing in the tax and benefit system. + if isinstance(variable, Variable): + return variable + + if not isinstance(variable, Variable) and not check_existence: + return variable + + raise VariableNotFoundError(variable_name, self) + + def neutralize_variable(self, variable_name: str) -> None: + """Neutralizes an OpenFisca variable existing in the tax and benefit system. A neutralized variable always returns its default value when computed. Trying to set inputs for a neutralized variable has no effect except raising a warning. """ - self.variables[variable_name] = variables.get_neutralized_variable(self.get_variable(variable_name)) + self.variables[variable_name] = variables.get_neutralized_variable( + self.get_variable(variable_name), + ) + + def annualize_variable( + self, + variable_name: str, + period: Period | None = None, + ) -> None: + check: bool + variable: Variable | None + annualised_variable: Variable + + check = bool(period) + variable = self.get_variable(variable_name, check) + + if variable is None: + raise VariableNotFoundError(variable_name, self) - def annualize_variable(self, variable_name: str, period: typing.Optional[Period] = None): - self.variables[variable_name] = variables.get_annualized_variable(self.get_variable(variable_name, period)) + annualised_variable = variables.get_annualized_variable(variable) - def load_parameters(self, path_to_yaml_dir): - """ - Loads the legislation parameter for a directory containing YAML parameters files. + self.variables[variable_name] = annualised_variable + + def load_parameters(self, path_to_yaml_dir) -> None: + """Loads the legislation parameter for a directory containing YAML parameters files. :param path_to_yaml_dir: Absolute path towards the YAML parameter directory. Example: + >>> self.load_parameters("/path/to/yaml/parameters/dir") - >>> self.load_parameters('/path/to/yaml/parameters/dir') """ - - parameters = ParameterNode('', directory_path = path_to_yaml_dir) + parameters = ParameterNode("", directory_path=path_to_yaml_dir) if self.preprocess_parameters is not None: parameters = self.preprocess_parameters(parameters) @@ -334,36 +442,48 @@ def _get_baseline_parameters_at_instant(self, instant): return self.get_parameters_at_instant(instant) return baseline._get_baseline_parameters_at_instant(instant) - def get_parameters_at_instant(self, instant): - """ - Get the parameters of the legislation at a given instant + @functools.lru_cache + def get_parameters_at_instant( + self, + instant: str | int | Period | Instant, + ) -> ParameterNodeAtInstant | None: + """Get the parameters of the legislation at a given instant. + + Args: + instant: :obj:`str` formatted "YYYY-MM-DD" or :class:`~openfisca_core.periods.Instant`. + + Returns: + The parameters of the legislation at a given instant. - :param instant: :obj:`str` of the format 'YYYY-MM-DD' or :class:`.Instant` instance. - :returns: The parameters of the legislation at a given instant. - :rtype: :class:`.ParameterNodeAtInstant` """ - if isinstance(instant, Period): - instant = instant.start + key: Instant | None + msg: str + + if isinstance(instant, Instant): + key = instant + + elif isinstance(instant, Period): + key = instant.start + elif isinstance(instant, (str, int)): - instant = periods.instant(instant) + key = periods.instant(instant) + else: - assert isinstance(instant, Instant), "Expected an Instant (e.g. Instant((2017, 1, 1)) ). Got: {}.".format(instant) + msg = f"Expected an Instant (e.g. Instant((2017, 1, 1)) ). Got: {key}." + raise AssertionError(msg) - parameters_at_instant = self._parameters_at_instant_cache.get(instant) - if parameters_at_instant is None and self.parameters is not None: - parameters_at_instant = self.parameters.get_at_instant(str(instant)) - self._parameters_at_instant_cache[instant] = parameters_at_instant - return parameters_at_instant + if self.parameters is None: + return None - def get_package_metadata(self): - """ - Gets metatada relative to the country package the tax and benefit system is built from. + return self.parameters.get_at_instant(key) - :returns: Country package metadata - :rtype: dict + def get_package_metadata(self) -> dict[str, str]: + """Gets metadata relative to the country package. - Example: + Returns: + A dictionary with the country package metadata + Example: >>> tax_benefit_system.get_package_metadata() >>> { >>> 'location': '/path/to/dir/containing/package', @@ -371,74 +491,93 @@ def get_package_metadata(self): >>> 'repository_url': 'https://github.com/openfisca/openfisca-france', >>> 'version': '17.2.0' >>> } + """ # Handle reforms if self.baseline: return self.baseline.get_package_metadata() - fallback_metadata = { - 'name': self.__class__.__name__, - 'version': '', - 'repository_url': '', - 'location': '', - } - module = inspect.getmodule(self) - if not module.__package__: - return fallback_metadata - package_name = module.__package__.split('.')[0] + try: - distribution = pkg_resources.get_distribution(package_name) - except pkg_resources.DistributionNotFound: - return fallback_metadata + source_file = inspect.getsourcefile(module) + package_name = module.__package__.split(".")[0] + distribution = importlib.metadata.distribution(package_name) + source_metadata = distribution.metadata + except Exception as e: + log.warning("Unable to load package metadata, exposing default metadata", e) + source_metadata = { + "Name": self.__class__.__name__, + "Version": "0.0.0", + "Home-page": "https://openfisca.org", + } - location = inspect.getsourcefile(module).split(package_name)[0].rstrip('/') + try: + source_file = inspect.getsourcefile(module) + location = source_file.split(package_name)[0].rstrip("/") + except Exception as e: + log.warning("Unable to load package source folder", e) + location = "_unknown_" + + repository_url = "" + if source_metadata.get("Project-URL"): # pyproject.toml metadata format + repository_url = next( + filter( + lambda url: url.startswith("Repository"), + source_metadata.get_all("Project-URL"), + ), + ).split("Repository, ")[-1] + else: # setup.py format + repository_url = source_metadata.get("Home-page") - home_page_metadatas = [ - metadata.split(':', 1)[1].strip(' ') - for metadata in distribution._get_metadata(distribution.PKG_INFO) if 'Home-page' in metadata - ] - repository_url = home_page_metadatas[0] if home_page_metadatas else '' return { - 'name': distribution.key, - 'version': distribution.version, - 'repository_url': repository_url, - 'location': location, - } + "name": source_metadata.get("Name").lower(), + "version": source_metadata.get("Version"), + "repository_url": repository_url, + "location": location, + } - def get_variables(self, entity = None): - """ - Gets all variables contained in a tax and benefit system. + def get_variables( + self, + entity: Entity | None = None, + ) -> dict[str, Variable]: + """Gets all variables contained in a tax and benefit system. - :param .Entity entity: If set, returns only the variable defined for the given entity. + Args: + entity: If set, returns the variable defined for the given entity. - :returns: A dictionnary, indexed by variable names. - :rtype: dict + Returns: + A dictionary, indexed by variable names. """ if not entity: return self.variables - else: - return { - variable_name: variable - for variable_name, variable in self.variables.items() - if variable.entity == entity - } + return { + variable_name: variable + for variable_name, variable in self.variables.items() + # TODO - because entities are copied (see constructor) they can't be compared + if variable.entity.key == entity.key + } def clone(self): new = commons.empty_clone(self) new_dict = new.__dict__ for key, value in self.__dict__.items(): - if key not in ('parameters', '_parameters_at_instant_cache', 'variables', 'open_api_config'): + if key not in ( + "parameters", + "_parameters_at_instant_cache", + "variables", + "open_api_config", + ): new_dict[key] = value - for entity in new_dict['entities']: - entity.tax_benefit_system = new + for entity in new_dict["entities"]: + entity.set_tax_benefit_system(new) - new_dict['parameters'] = self.parameters.clone() - new_dict['_parameters_at_instant_cache'] = {} - new_dict['variables'] = self.variables.copy() - new_dict['open_api_config'] = self.open_api_config.copy() + new_dict["parameters"] = self.parameters.clone() + new_dict["_parameters_at_instant_cache"] = {} + new_dict["variables"] = self.variables.copy() + new_dict["open_api_config"] = self.open_api_config.copy() return new def entities_plural(self): diff --git a/openfisca_core/taxscales/__init__.py b/openfisca_core/taxscales/__init__.py index 0364101d71..1911d20c56 100644 --- a/openfisca_core/taxscales/__init__.py +++ b/openfisca_core/taxscales/__init__.py @@ -23,13 +23,13 @@ from openfisca_core.errors import EmptyArgumentError # noqa: F401 -from .helpers import combine_tax_scales # noqa: F401 -from .tax_scale_like import TaxScaleLike # noqa: F401 -from .rate_tax_scale_like import RateTaxScaleLike # noqa: F401 -from .marginal_rate_tax_scale import MarginalRateTaxScale # noqa: F401 -from .linear_average_rate_tax_scale import LinearAverageRateTaxScale # noqa: F401 +from .abstract_rate_tax_scale import AbstractRateTaxScale # noqa: F401 from .abstract_tax_scale import AbstractTaxScale # noqa: F401 from .amount_tax_scale_like import AmountTaxScaleLike # noqa: F401 -from .abstract_rate_tax_scale import AbstractRateTaxScale # noqa: F401 +from .helpers import combine_tax_scales # noqa: F401 +from .linear_average_rate_tax_scale import LinearAverageRateTaxScale # noqa: F401 from .marginal_amount_tax_scale import MarginalAmountTaxScale # noqa: F401 +from .marginal_rate_tax_scale import MarginalRateTaxScale # noqa: F401 +from .rate_tax_scale_like import RateTaxScaleLike # noqa: F401 from .single_amount_tax_scale import SingleAmountTaxScale # noqa: F401 +from .tax_scale_like import TaxScaleLike # noqa: F401 diff --git a/openfisca_core/taxscales/abstract_rate_tax_scale.py b/openfisca_core/taxscales/abstract_rate_tax_scale.py index b9316273d1..9d828ed673 100644 --- a/openfisca_core/taxscales/abstract_rate_tax_scale.py +++ b/openfisca_core/taxscales/abstract_rate_tax_scale.py @@ -1,41 +1,42 @@ from __future__ import annotations import typing + import warnings -from openfisca_core.taxscales import RateTaxScaleLike +from .rate_tax_scale_like import RateTaxScaleLike if typing.TYPE_CHECKING: import numpy - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class AbstractRateTaxScale(RateTaxScaleLike): - """ - Base class for various types of rate-based tax scales: marginal rate, + """Base class for various types of rate-based tax scales: marginal rate, linear average rate... """ def __init__( - self, name: typing.Optional[str] = None, - option: typing.Any = None, - unit: typing.Any = None, - ) -> None: + self, + name: str | None = None, + option: typing.Any = None, + unit: typing.Any = None, + ) -> None: message = [ "The 'AbstractRateTaxScale' class has been deprecated since", "version 34.7.0, and will be removed in the future.", - ] + ] - warnings.warn(" ".join(message), DeprecationWarning) + warnings.warn(" ".join(message), DeprecationWarning, stacklevel=2) super().__init__(name, option, unit) def calc( - self, - tax_base: NumericalArray, - right: bool, - ) -> typing.NoReturn: + self, + tax_base: NumericalArray, + right: bool, + ) -> typing.NoReturn: + msg = "Method 'calc' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method 'calc' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) diff --git a/openfisca_core/taxscales/abstract_tax_scale.py b/openfisca_core/taxscales/abstract_tax_scale.py index 9cbeeb7565..de9a6348c5 100644 --- a/openfisca_core/taxscales/abstract_tax_scale.py +++ b/openfisca_core/taxscales/abstract_tax_scale.py @@ -1,55 +1,54 @@ from __future__ import annotations import typing + import warnings -from openfisca_core.taxscales import TaxScaleLike +from .tax_scale_like import TaxScaleLike if typing.TYPE_CHECKING: import numpy - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class AbstractTaxScale(TaxScaleLike): - """ - Base class for various types of tax scales: amount-based tax scales, + """Base class for various types of tax scales: amount-based tax scales, rate-based tax scales... """ def __init__( - self, - name: typing.Optional[str] = None, - option: typing.Any = None, - unit: numpy.int_ = None, - ) -> None: - + self, + name: str | None = None, + option: typing.Any = None, + unit: numpy.int16 = None, + ) -> None: message = [ "The 'AbstractTaxScale' class has been deprecated since", "version 34.7.0, and will be removed in the future.", - ] + ] - warnings.warn(" ".join(message), DeprecationWarning) + warnings.warn(" ".join(message), DeprecationWarning, stacklevel=2) super().__init__(name, option, unit) def __repr__(self) -> typing.NoReturn: + msg = "Method '__repr__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__repr__' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) def calc( - self, - tax_base: NumericalArray, - right: bool, - ) -> typing.NoReturn: + self, + tax_base: NumericalArray, + right: bool, + ) -> typing.NoReturn: + msg = "Method 'calc' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method 'calc' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) def to_dict(self) -> typing.NoReturn: + msg = f"Method 'to_dict' is not implemented for {self.__class__.__name__}" raise NotImplementedError( - f"Method 'to_dict' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) diff --git a/openfisca_core/taxscales/amount_tax_scale_like.py b/openfisca_core/taxscales/amount_tax_scale_like.py index cfc0a6973f..1dc9acf4b3 100644 --- a/openfisca_core/taxscales/amount_tax_scale_like.py +++ b/openfisca_core/taxscales/amount_tax_scale_like.py @@ -1,26 +1,27 @@ +import typing + import abc import bisect import os -import typing from openfisca_core import tools -from openfisca_core.taxscales import TaxScaleLike + +from .tax_scale_like import TaxScaleLike class AmountTaxScaleLike(TaxScaleLike, abc.ABC): - """ - Base class for various types of amount-based tax scales: single amount, + """Base class for various types of amount-based tax scales: single amount, marginal amount... """ - amounts: typing.List + amounts: list def __init__( - self, - name: typing.Optional[str] = None, - option: typing.Any = None, - unit: typing.Any = None, - ) -> None: + self, + name: typing.Optional[str] = None, + option: typing.Any = None, + unit: typing.Any = None, + ) -> None: super().__init__(name, option, unit) self.amounts = [] @@ -29,17 +30,16 @@ def __repr__(self) -> str: os.linesep.join( [ f"- threshold: {threshold}{os.linesep} amount: {amount}" - for (threshold, amount) - in zip(self.thresholds, self.amounts) - ] - ) - ) + for (threshold, amount) in zip(self.thresholds, self.amounts) + ], + ), + ) def add_bracket( - self, - threshold: int, - amount: typing.Union[int, float], - ) -> None: + self, + threshold: int, + amount: typing.Union[int, float], + ) -> None: if threshold in self.thresholds: i = self.thresholds.index(threshold) self.amounts[i] += amount @@ -52,6 +52,5 @@ def add_bracket( def to_dict(self) -> dict: return { str(threshold): self.amounts[index] - for index, threshold - in enumerate(self.thresholds) - } + for index, threshold in enumerate(self.thresholds) + } diff --git a/openfisca_core/taxscales/helpers.py b/openfisca_core/taxscales/helpers.py index 181fbfed36..687db41a3b 100644 --- a/openfisca_core/taxscales/helpers.py +++ b/openfisca_core/taxscales/helpers.py @@ -1,8 +1,9 @@ from __future__ import annotations -import logging import typing +import logging + from openfisca_core import taxscales log = logging.getLogger(__name__) @@ -14,21 +15,19 @@ def combine_tax_scales( - node: ParameterNodeAtInstant, - combined_tax_scales: TaxScales = None, - ) -> TaxScales: - """ - Combine all the MarginalRateTaxScales in the node into a single + node: ParameterNodeAtInstant, + combined_tax_scales: TaxScales = None, +) -> TaxScales: + """Combine all the MarginalRateTaxScales in the node into a single MarginalRateTaxScale. """ - name = next(iter(node or []), None) if name is None: return combined_tax_scales if combined_tax_scales is None: - combined_tax_scales = taxscales.MarginalRateTaxScale(name = name) + combined_tax_scales = taxscales.MarginalRateTaxScale(name=name) combined_tax_scales.add_bracket(0, 0) for child_name in node: @@ -41,6 +40,6 @@ def combine_tax_scales( log.info( f"Skipping {child_name} with value {child} " "because it is not a marginal rate tax scale", - ) + ) return combined_tax_scales diff --git a/openfisca_core/taxscales/linear_average_rate_tax_scale.py b/openfisca_core/taxscales/linear_average_rate_tax_scale.py index d1fe9c8094..ffccfc2205 100644 --- a/openfisca_core/taxscales/linear_average_rate_tax_scale.py +++ b/openfisca_core/taxscales/linear_average_rate_tax_scale.py @@ -1,26 +1,27 @@ from __future__ import annotations -import logging import typing +import logging + import numpy from openfisca_core import taxscales -from openfisca_core.taxscales import RateTaxScaleLike + +from .rate_tax_scale_like import RateTaxScaleLike log = logging.getLogger(__name__) if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class LinearAverageRateTaxScale(RateTaxScaleLike): - def calc( - self, - tax_base: NumericalArray, - right: bool = False, - ) -> numpy.float_: + self, + tax_base: NumericalArray, + right: bool = False, + ) -> numpy.float32: if len(self.rates) == 1: return tax_base * self.rates[0] @@ -28,17 +29,15 @@ def calc( tiled_thresholds = numpy.tile(self.thresholds, (len(tax_base), 1)) bracket_dummy = (tiled_base >= tiled_thresholds[:, :-1]) * ( - + tiled_base - < tiled_thresholds[:, 1:] - ) + +tiled_base < tiled_thresholds[:, 1:] + ) rates_array = numpy.array(self.rates) thresholds_array = numpy.array(self.thresholds) rate_slope = (rates_array[1:] - rates_array[:-1]) / ( - + thresholds_array[1:] - - thresholds_array[:-1] - ) + +thresholds_array[1:] - thresholds_array[:-1] + ) average_rate_slope = numpy.dot(bracket_dummy, rate_slope.T) @@ -49,17 +48,16 @@ def calc( log.info(f"average_rate_slope: {average_rate_slope}") return tax_base * ( - + bracket_average_start_rate - + (tax_base - bracket_threshold) - * average_rate_slope - ) + +bracket_average_start_rate + + (tax_base - bracket_threshold) * average_rate_slope + ) def to_marginal(self) -> taxscales.MarginalRateTaxScale: marginal_tax_scale = taxscales.MarginalRateTaxScale( - name = self.name, - option = self.option, - unit = self.unit, - ) + name=self.name, + option=self.option, + unit=self.unit, + ) previous_i = 0 previous_threshold = 0 @@ -70,7 +68,7 @@ def to_marginal(self) -> taxscales.MarginalRateTaxScale: marginal_tax_scale.add_bracket( previous_threshold, (i - previous_i) / (threshold - previous_threshold), - ) + ) previous_i = i previous_threshold = threshold diff --git a/openfisca_core/taxscales/marginal_amount_tax_scale.py b/openfisca_core/taxscales/marginal_amount_tax_scale.py index 348f2445c0..aa96bff57b 100644 --- a/openfisca_core/taxscales/marginal_amount_tax_scale.py +++ b/openfisca_core/taxscales/marginal_amount_tax_scale.py @@ -4,31 +4,31 @@ import numpy -from openfisca_core.taxscales import AmountTaxScaleLike +from .amount_tax_scale_like import AmountTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class MarginalAmountTaxScale(AmountTaxScaleLike): - def calc( - self, - tax_base: NumericalArray, - right: bool = False, - ) -> numpy.float_: - """ - Matches the input amount to a set of brackets and returns the sum of + self, + tax_base: NumericalArray, + right: bool = False, + ) -> numpy.float32: + """Matches the input amount to a set of brackets and returns the sum of cell values from the lowest bracket to the one containing the input. """ base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T thresholds1 = numpy.tile( - numpy.hstack((self.thresholds, numpy.inf)), (len(tax_base), 1) - ) + numpy.hstack((self.thresholds, numpy.inf)), + (len(tax_base), 1), + ) a = numpy.maximum( - numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0 - ) + numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], + 0, + ) return numpy.dot(self.amounts, a.T > 0) diff --git a/openfisca_core/taxscales/marginal_rate_tax_scale.py b/openfisca_core/taxscales/marginal_rate_tax_scale.py index 38331e0bb8..803a5f8547 100644 --- a/openfisca_core/taxscales/marginal_rate_tax_scale.py +++ b/openfisca_core/taxscales/marginal_rate_tax_scale.py @@ -1,44 +1,44 @@ from __future__ import annotations +import typing + import bisect import itertools -import typing import numpy from openfisca_core import taxscales -from openfisca_core.taxscales import RateTaxScaleLike + +from .rate_tax_scale_like import RateTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class MarginalRateTaxScale(RateTaxScaleLike): - def add_tax_scale(self, tax_scale: RateTaxScaleLike) -> None: # So as not to have problems with empty scales - if (len(tax_scale.thresholds) > 0): + if len(tax_scale.thresholds) > 0: for threshold_low, threshold_high, rate in zip( - tax_scale.thresholds[:-1], - tax_scale.thresholds[1:], - tax_scale.rates, - ): + tax_scale.thresholds[:-1], + tax_scale.thresholds[1:], + tax_scale.rates, + ): self.combine_bracket(rate, threshold_low, threshold_high) # To process the last threshold self.combine_bracket( tax_scale.rates[-1], tax_scale.thresholds[-1], - ) + ) def calc( - self, - tax_base: NumericalArray, - factor: float = 1.0, - round_base_decimals: typing.Optional[int] = None, - ) -> numpy.float_: - """ - Compute the tax amount for the given tax bases by applying a taxscale. + self, + tax_base: NumericalArray, + factor: float = 1.0, + round_base_decimals: int | None = None, + ) -> numpy.float32: + """Compute the tax amount for the given tax bases by applying a taxscale. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds of the taxscale. @@ -67,31 +67,31 @@ def calc( # # numpy.finfo(float_).eps thresholds1 = numpy.outer( - factor + numpy.finfo(numpy.float_).eps, - numpy.array(self.thresholds + [numpy.inf]), - ) + factor + numpy.finfo(numpy.float64).eps, + numpy.array([*self.thresholds, numpy.inf]), + ) if round_base_decimals is not None: - thresholds1 = numpy.round_(thresholds1, round_base_decimals) + thresholds1 = numpy.round(thresholds1, round_base_decimals) a = numpy.maximum( - numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0 - ) + numpy.minimum(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], + 0, + ) if round_base_decimals is None: return numpy.dot(self.rates, a.T) - else: - r = numpy.tile(self.rates, (len(tax_base), 1)) - b = numpy.round_(a, round_base_decimals) - return numpy.round_(r * b, round_base_decimals).sum(axis = 1) + r = numpy.tile(self.rates, (len(tax_base), 1)) + b = numpy.round(a, round_base_decimals) + return numpy.round(r * b, round_base_decimals).sum(axis=1) def combine_bracket( - self, - rate: typing.Union[int, float], - threshold_low: int = 0, - threshold_high: typing.Union[int, bool] = False, - ) -> None: + self, + rate: int | float, + threshold_low: int = 0, + threshold_high: int | bool = False, + ) -> None: # Insert threshold_low and threshold_high without modifying rates if threshold_low not in self.thresholds: index = bisect.bisect_right(self.thresholds, threshold_low) - 1 @@ -115,13 +115,12 @@ def combine_bracket( i += 1 def marginal_rates( - self, - tax_base: NumericalArray, - factor: float = 1.0, - round_base_decimals: typing.Optional[int] = None, - ) -> numpy.float_: - """ - Compute the marginal tax rates relevant for the given tax bases. + self, + tax_base: NumericalArray, + factor: float = 1.0, + round_base_decimals: int | None = None, + ) -> numpy.float32: + """Compute the marginal tax rates relevant for the given tax bases. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds of a tax scale. @@ -144,13 +143,72 @@ def marginal_rates( tax_base, factor, round_base_decimals, - ) + ) return numpy.array(self.rates)[bracket_indices] - def inverse(self) -> MarginalRateTaxScale: + def rate_from_bracket_indice( + self, + bracket_indice: numpy.int16, + ) -> numpy.float32: + """Compute the relevant tax rates for the given bracket indices. + + :param: ndarray bracket_indice: Array of the bracket indices. + + :returns: Floating array with relevant tax rates + for the given bracket indices. + + For instance: + + >>> import numpy + >>> tax_scale = MarginalRateTaxScale() + >>> tax_scale.add_bracket(0, 0) + >>> tax_scale.add_bracket(200, 0.1) + >>> tax_scale.add_bracket(500, 0.25) + >>> tax_base = numpy.array([50, 1_000, 250]) + >>> bracket_indice = tax_scale.bracket_indices(tax_base) + >>> tax_scale.rate_from_bracket_indice(bracket_indice) + array([0. , 0.25, 0.1 ]) + """ + if bracket_indice.max() > len(self.rates) - 1: + msg = ( + f"bracket_indice parameter ({bracket_indice}) " + f"contains one or more bracket indice which is unavailable " + f"inside current {self.__class__.__name__} :\n" + f"{self}" + ) + raise IndexError( + msg, + ) + + return numpy.array(self.rates)[bracket_indice] + + def rate_from_tax_base( + self, + tax_base: NumericalArray, + ) -> numpy.float32: + """Compute the relevant tax rates for the given tax bases. + + :param: ndarray tax_base: Array of the tax bases. + + :returns: Floating array with relevant tax rates + for the given tax bases. + + For instance: + + >>> import numpy + >>> tax_scale = MarginalRateTaxScale() + >>> tax_scale.add_bracket(0, 0) + >>> tax_scale.add_bracket(200, 0.1) + >>> tax_scale.add_bracket(500, 0.25) + >>> tax_base = numpy.array([1_000, 50, 450]) + >>> tax_scale.rate_from_tax_base(tax_base) + array([0.25, 0. , 0.1 ]) """ - Returns a new instance of MarginalRateTaxScale. + return self.rate_from_bracket_indice(self.bracket_indices(tax_base)) + + def inverse(self) -> MarginalRateTaxScale: + """Returns a new instance of MarginalRateTaxScale. Invert a taxscale: @@ -176,10 +234,10 @@ def inverse(self) -> MarginalRateTaxScale: # Actually 1 / (1 - global_rate) inverse = self.__class__( - name = str(self.name) + "'", - option = self.option, - unit = self.unit, - ) + name=str(self.name) + "'", + option=self.option, + unit=self.unit, + ) for threshold, rate in zip(self.thresholds, self.rates): if threshold == 0: @@ -202,10 +260,10 @@ def scale_tax_scales(self, factor: float) -> MarginalRateTaxScale: def to_average(self) -> taxscales.LinearAverageRateTaxScale: average_tax_scale = taxscales.LinearAverageRateTaxScale( - name = self.name, - option = self.option, - unit = self.unit, - ) + name=self.name, + option=self.option, + unit=self.unit, + ) average_tax_scale.add_bracket(0, 0) @@ -215,10 +273,10 @@ def to_average(self) -> taxscales.LinearAverageRateTaxScale: previous_rate = self.rates[0] for threshold, rate in itertools.islice( - zip(self.thresholds, self.rates), - 1, - None, - ): + zip(self.thresholds, self.rates), + 1, + None, + ): i += previous_rate * (threshold - previous_threshold) average_tax_scale.add_bracket(threshold, i / threshold) previous_threshold = threshold diff --git a/openfisca_core/taxscales/rate_tax_scale_like.py b/openfisca_core/taxscales/rate_tax_scale_like.py index 824a94debe..288226f11e 100644 --- a/openfisca_core/taxscales/rate_tax_scale_like.py +++ b/openfisca_core/taxscales/rate_tax_scale_like.py @@ -1,34 +1,35 @@ from __future__ import annotations +import typing + import abc import bisect import os -import typing import numpy from openfisca_core import tools from openfisca_core.errors import EmptyArgumentError -from openfisca_core.taxscales import TaxScaleLike + +from .tax_scale_like import TaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class RateTaxScaleLike(TaxScaleLike, abc.ABC): - """ - Base class for various types of rate-based tax scales: marginal rate, + """Base class for various types of rate-based tax scales: marginal rate, linear average rate... """ - rates: typing.List + rates: list def __init__( - self, - name: typing.Optional[str] = None, - option: typing.Any = None, - unit: typing.Any = None, - ) -> None: + self, + name: str | None = None, + option: typing.Any = None, + unit: typing.Any = None, + ) -> None: super().__init__(name, option, unit) self.rates = [] @@ -37,17 +38,16 @@ def __repr__(self) -> str: os.linesep.join( [ f"- threshold: {threshold}{os.linesep} rate: {rate}" - for (threshold, rate) - in zip(self.thresholds, self.rates) - ] - ) - ) + for (threshold, rate) in zip(self.thresholds, self.rates) + ], + ), + ) def add_bracket( - self, - threshold: typing.Union[int, float], - rate: typing.Union[int, float], - ) -> None: + self, + threshold: int | float, + rate: int | float, + ) -> None: if threshold in self.thresholds: i = self.thresholds.index(threshold) self.rates[i] += rate @@ -58,11 +58,11 @@ def add_bracket( self.rates.insert(i, rate) def multiply_rates( - self, - factor: float, - inplace: bool = True, - new_name: typing.Optional[str] = None, - ) -> RateTaxScaleLike: + self, + factor: float, + inplace: bool = True, + new_name: str | None = None, + ) -> RateTaxScaleLike: if inplace: assert new_name is None @@ -73,9 +73,9 @@ def multiply_rates( new_tax_scale = self.__class__( new_name or self.name, - option = self.option, - unit = self.unit, - ) + option=self.option, + unit=self.unit, + ) for threshold, rate in zip(self.thresholds, self.rates): new_tax_scale.thresholds.append(threshold) @@ -84,12 +84,12 @@ def multiply_rates( return new_tax_scale def multiply_thresholds( - self, - factor: float, - decimals: typing.Optional[int] = None, - inplace: bool = True, - new_name: typing.Optional[str] = None, - ) -> RateTaxScaleLike: + self, + factor: float, + decimals: int | None = None, + inplace: bool = True, + new_name: str | None = None, + ) -> RateTaxScaleLike: if inplace: assert new_name is None @@ -97,8 +97,8 @@ def multiply_thresholds( if decimals is not None: self.thresholds[i] = numpy.around( threshold * factor, - decimals = decimals, - ) + decimals=decimals, + ) else: self.thresholds[i] = threshold * factor @@ -107,15 +107,15 @@ def multiply_thresholds( new_tax_scale = self.__class__( new_name or self.name, - option = self.option, - unit = self.unit, - ) + option=self.option, + unit=self.unit, + ) for threshold, rate in zip(self.thresholds, self.rates): if decimals is not None: new_tax_scale.thresholds.append( - numpy.around(threshold * factor, decimals = decimals), - ) + numpy.around(threshold * factor, decimals=decimals), + ) else: new_tax_scale.thresholds.append(threshold * factor) @@ -124,13 +124,12 @@ def multiply_thresholds( return new_tax_scale def bracket_indices( - self, - tax_base: NumericalArray, - factor: float = 1.0, - round_decimals: typing.Optional[int] = None, - ) -> numpy.int_: - """ - Compute the relevant bracket indices for the given tax bases. + self, + tax_base: NumericalArray, + factor: float = 1.0, + round_decimals: int | None = None, + ) -> numpy.int32: + """Compute the relevant bracket indices for the given tax bases. :param ndarray tax_base: Array of the tax bases. :param float factor: Factor to apply to the thresholds. @@ -148,14 +147,13 @@ def bracket_indices( >>> tax_scale.bracket_indices(tax_base) [0, 1] """ - if not numpy.size(numpy.array(self.thresholds)): raise EmptyArgumentError( self.__class__.__name__, "bracket_indices", "self.thresholds", self.thresholds, - ) + ) if not numpy.size(numpy.asarray(tax_base)): raise EmptyArgumentError( @@ -163,7 +161,7 @@ def bracket_indices( "bracket_indices", "tax_base", tax_base, - ) + ) base1 = numpy.tile(tax_base, (len(self.thresholds), 1)).T factor = numpy.ones(len(tax_base)) * factor @@ -176,18 +174,42 @@ def bracket_indices( # # numpy.finfo(float_).eps thresholds1 = numpy.outer( - + factor - + numpy.finfo(numpy.float_).eps, numpy.array(self.thresholds) - ) + +factor + numpy.finfo(numpy.float64).eps, + numpy.array(self.thresholds), + ) if round_decimals is not None: - thresholds1 = numpy.round_(thresholds1, round_decimals) + thresholds1 = numpy.round(thresholds1, round_decimals) + + return (base1 - thresholds1 >= 0).sum(axis=1) - 1 + + def threshold_from_tax_base( + self, + tax_base: NumericalArray, + ) -> NumericalArray: + """Compute the relevant thresholds for the given tax bases. - return (base1 - thresholds1 >= 0).sum(axis = 1) - 1 + :param: ndarray tax_base: Array of the tax bases. + + :returns: Floating array with relevant thresholds + for the given tax bases. + + For instance: + + >>> import numpy + >>> from openfisca_core import taxscales + >>> tax_scale = taxscales.MarginalRateTaxScale() + >>> tax_scale.add_bracket(0, 0) + >>> tax_scale.add_bracket(200, 0.1) + >>> tax_scale.add_bracket(500, 0.25) + >>> tax_base = numpy.array([450, 1_150, 10]) + >>> tax_scale.threshold_from_tax_base(tax_base) + array([200, 500, 0]) + """ + return numpy.array(self.thresholds)[self.bracket_indices(tax_base)] def to_dict(self) -> dict: return { str(threshold): self.rates[index] - for index, threshold - in enumerate(self.thresholds) - } + for index, threshold in enumerate(self.thresholds) + } diff --git a/openfisca_core/taxscales/single_amount_tax_scale.py b/openfisca_core/taxscales/single_amount_tax_scale.py index bdfee48010..1c8cf69a32 100644 --- a/openfisca_core/taxscales/single_amount_tax_scale.py +++ b/openfisca_core/taxscales/single_amount_tax_scale.py @@ -7,36 +7,26 @@ from openfisca_core.taxscales import AmountTaxScaleLike if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + NumericalArray = typing.Union[numpy.int32, numpy.float32] class SingleAmountTaxScale(AmountTaxScaleLike): - def calc( - self, - tax_base: NumericalArray, - right: bool = False, - ) -> numpy.float_: - """ - Matches the input amount to a set of brackets and returns the single + self, + tax_base: NumericalArray, + right: bool = False, + ) -> numpy.float32: + """Matches the input amount to a set of brackets and returns the single cell value that fits within that bracket. """ - guarded_thresholds = numpy.array( - [-numpy.inf] - + self.thresholds - + [numpy.inf] - ) + guarded_thresholds = numpy.array([-numpy.inf, *self.thresholds, numpy.inf]) bracket_indices = numpy.digitize( tax_base, guarded_thresholds, - right = right, - ) - - guarded_amounts = numpy.array( - [0] - + self.amounts - + [0] - ) + right=right, + ) + + guarded_amounts = numpy.array([0, *self.amounts, 0]) return guarded_amounts[bracket_indices - 1] diff --git a/openfisca_core/taxscales/tax_scale_like.py b/openfisca_core/taxscales/tax_scale_like.py index 8177ee0505..e8680b9f8f 100644 --- a/openfisca_core/taxscales/tax_scale_like.py +++ b/openfisca_core/taxscales/tax_scale_like.py @@ -1,67 +1,64 @@ from __future__ import annotations -import abc -import copy import typing -import numpy +import abc +import copy from openfisca_core import commons if typing.TYPE_CHECKING: - NumericalArray = typing.Union[numpy.int_, numpy.float_] + import numpy + + NumericalArray = typing.Union[numpy.int32, numpy.float32] class TaxScaleLike(abc.ABC): - """ - Base class for various types of tax scales: amount-based tax scales, + """Base class for various types of tax scales: amount-based tax scales, rate-based tax scales... """ - name: typing.Optional[str] + name: str | None option: typing.Any unit: typing.Any - thresholds: typing.List + thresholds: list @abc.abstractmethod def __init__( - self, - name: typing.Optional[str] = None, - option: typing.Any = None, - unit: typing.Any = None, - ) -> None: + self, + name: str | None = None, + option: typing.Any = None, + unit: typing.Any = None, + ) -> None: self.name = name or "Untitled TaxScale" self.option = option self.unit = unit self.thresholds = [] def __eq__(self, _other: object) -> typing.NoReturn: + msg = "Method '__eq__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__eq__' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) def __ne__(self, _other: object) -> typing.NoReturn: + msg = "Method '__ne__' is not implemented for " f"{self.__class__.__name__}" raise NotImplementedError( - "Method '__ne__' is not implemented for " - f"{self.__class__.__name__}", - ) + msg, + ) @abc.abstractmethod - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... @abc.abstractmethod def calc( - self, - tax_base: NumericalArray, - right: bool, - ) -> numpy.float_: - ... + self, + tax_base: NumericalArray, + right: bool, + ) -> numpy.float32: ... @abc.abstractmethod - def to_dict(self) -> dict: - ... + def to_dict(self) -> dict: ... def copy(self) -> typing.Any: new = commons.empty_clone(self) diff --git a/openfisca_core/tools/__init__.py b/openfisca_core/tools/__init__.py index 9b1dd2cc5d..952dca6ebd 100644 --- a/openfisca_core/tools/__init__.py +++ b/openfisca_core/tools/__init__.py @@ -1,65 +1,70 @@ -# -*- coding: utf-8 -*- - - import os -import numexpr - +from openfisca_core import commons from openfisca_core.indexed_enums import EnumArray -def assert_near(value, target_value, absolute_error_margin = None, message = '', relative_error_margin = None): - ''' - - :param value: Value returned by the test - :param target_value: Value that the test should return to pass - :param absolute_error_margin: Absolute error margin authorized - :param message: Error message to be displayed if the test fails - :param relative_error_margin: Relative error margin authorized - - Limit : This function cannot be used to assert near periods. +def assert_near( + value, + target_value, + absolute_error_margin=None, + message="", + relative_error_margin=None, +): + """:param value: Value returned by the test + :param target_value: Value that the test should return to pass + :param absolute_error_margin: Absolute error margin authorized + :param message: Error message to be displayed if the test fails + :param relative_error_margin: Relative error margin authorized - ''' + Limit : This function cannot be used to assert near periods. - import numpy as np + """ + import numpy if absolute_error_margin is None and relative_error_margin is None: absolute_error_margin = 0 - if not isinstance(value, np.ndarray): - value = np.array(value) + if not isinstance(value, numpy.ndarray): + value = numpy.array(value) if isinstance(value, EnumArray): return assert_enum_equals(value, target_value, message) - if np.issubdtype(value.dtype, np.datetime64): - target_value = np.array(target_value, dtype = value.dtype) + if numpy.issubdtype(value.dtype, numpy.datetime64): + target_value = numpy.array(target_value, dtype=value.dtype) assert_datetime_equals(value, target_value, message) if isinstance(target_value, str): - target_value = eval_expression(target_value) + target_value = commons.eval_expression(target_value) - target_value = np.array(target_value).astype(np.float32) + target_value = numpy.array(target_value).astype(numpy.float32) - value = np.array(value).astype(np.float32) + value = numpy.array(value).astype(numpy.float32) diff = abs(target_value - value) if absolute_error_margin is not None: - assert (diff <= absolute_error_margin).all(), \ - '{}{} differs from {} with an absolute margin {} > {}'.format(message, value, target_value, - diff, absolute_error_margin) + assert ( + diff <= absolute_error_margin + ).all(), f"{message}{value} differs from {target_value} with an absolute margin {diff} > {absolute_error_margin}" if relative_error_margin is not None: - assert (diff <= abs(relative_error_margin * target_value)).all(), \ - '{}{} differs from {} with a relative margin {} > {}'.format(message, value, target_value, - diff, abs(relative_error_margin * target_value)) + assert ( + diff <= abs(relative_error_margin * target_value) + ).all(), f"{message}{value} differs from {target_value} with a relative margin {diff} > {abs(relative_error_margin * target_value)}" + return None + return None -def assert_datetime_equals(value, target_value, message = ''): - assert (value == target_value).all(), '{}{} differs from {}.'.format(message, value, target_value) +def assert_datetime_equals(value, target_value, message="") -> None: + assert ( + value == target_value + ).all(), f"{message}{value} differs from {target_value}." -def assert_enum_equals(value, target_value, message = ''): +def assert_enum_equals(value, target_value, message="") -> None: value = value.decode_to_str() - assert (value == target_value).all(), '{}{} differs from {}.'.format(message, value, target_value) + assert ( + value == target_value + ).all(), f"{message}{value} differs from {target_value}." def indent(text): - return " {}".format(text.replace(os.linesep, "{} ".format(os.linesep))) + return " {}".format(text.replace(os.linesep, f"{os.linesep} ")) def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): @@ -68,18 +73,16 @@ def get_trace_tool_link(scenario, variables, api_url, trace_tool_url): scenario_json = scenario.to_json() simulation_json = { - 'scenarios': [scenario_json], - 'variables': variables, - } - url = trace_tool_url + '?' + urllib.urlencode({ - 'simulation': json.dumps(simulation_json), - 'api_url': api_url, - }) - return url - - -def eval_expression(expression): - try: - return numexpr.evaluate(expression) - except (KeyError, TypeError): - return expression + "scenarios": [scenario_json], + "variables": variables, + } + return ( + trace_tool_url + + "?" + + urllib.urlencode( + { + "simulation": json.dumps(simulation_json), + "api_url": api_url, + }, + ) + ) diff --git a/openfisca_core/tools/simulation_dumper.py b/openfisca_core/tools/simulation_dumper.py index 4b5907c0ff..84898165fd 100644 --- a/openfisca_core/tools/simulation_dumper.py +++ b/openfisca_core/tools/simulation_dumper.py @@ -1,19 +1,14 @@ -# -*- coding: utf-8 -*- - - import os -import numpy as np +import numpy -from openfisca_core.simulations import Simulation from openfisca_core.data_storage import OnDiskStorage -from openfisca_core.periods import ETERNITY +from openfisca_core.periods import DateUnit +from openfisca_core.simulations import Simulation -def dump_simulation(simulation, directory): - """ - Write simulation data to directory, so that it can be restored later. - """ +def dump_simulation(simulation, directory) -> None: + """Write simulation data to directory, so that it can be restored later.""" parent_directory = os.path.abspath(os.path.join(directory, os.pardir)) if not os.path.isdir(parent_directory): # To deal with reforms os.mkdir(parent_directory) @@ -21,7 +16,8 @@ def dump_simulation(simulation, directory): os.mkdir(directory) if os.listdir(directory): - raise ValueError("Directory '{}' is not empty".format(directory)) + msg = f"Directory '{directory}' is not empty" + raise ValueError(msg) entities_dump_dir = os.path.join(directory, "__entities__") os.mkdir(entities_dump_dir) @@ -36,10 +32,11 @@ def dump_simulation(simulation, directory): def restore_simulation(directory, tax_benefit_system, **kwargs): - """ - Restore simulation from directory - """ - simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) + """Restore simulation from directory.""" + simulation = Simulation( + tax_benefit_system, + tax_benefit_system.instantiate_entities(), + ) entities_dump_dir = os.path.join(directory, "__entities__") for population in simulation.populations.values(): @@ -53,75 +50,84 @@ def restore_simulation(directory, tax_benefit_system, **kwargs): _restore_entity(population, entities_dump_dir) population.count = person_count - variables_to_restore = (variable for variable in os.listdir(directory) if variable != "__entities__") + variables_to_restore = ( + variable for variable in os.listdir(directory) if variable != "__entities__" + ) for variable in variables_to_restore: _restore_holder(simulation, variable, directory) return simulation -def _dump_holder(holder, directory): - disk_storage = holder.create_disk_storage(directory, preserve = True) +def _dump_holder(holder, directory) -> None: + disk_storage = holder.create_disk_storage(directory, preserve=True) for period in holder.get_known_periods(): value = holder.get_array(period) disk_storage.put(value, period) -def _dump_entity(population, directory): +def _dump_entity(population, directory) -> None: path = os.path.join(directory, population.entity.key) os.mkdir(path) - np.save(os.path.join(path, "id.npy"), population.ids) + numpy.save(os.path.join(path, "id.npy"), population.ids) if population.entity.is_person: return - np.save(os.path.join(path, "members_position.npy"), population.members_position) - np.save(os.path.join(path, "members_entity_id.npy"), population.members_entity_id) + numpy.save(os.path.join(path, "members_position.npy"), population.members_position) + numpy.save( + os.path.join(path, "members_entity_id.npy"), population.members_entity_id + ) flattened_roles = population.entity.flattened_roles if len(flattened_roles) == 0: - encoded_roles = np.int64(0) + encoded_roles = numpy.int16(0) else: - encoded_roles = np.select( + encoded_roles = numpy.select( [population.members_role == role for role in flattened_roles], [role.key for role in flattened_roles], - ) - np.save(os.path.join(path, "members_role.npy"), encoded_roles) + ) + numpy.save(os.path.join(path, "members_role.npy"), encoded_roles) def _restore_entity(population, directory): path = os.path.join(directory, population.entity.key) - population.ids = np.load(os.path.join(path, "id.npy")) + population.ids = numpy.load(os.path.join(path, "id.npy")) if population.entity.is_person: - return + return None - population.members_position = np.load(os.path.join(path, "members_position.npy")) - population.members_entity_id = np.load(os.path.join(path, "members_entity_id.npy")) - encoded_roles = np.load(os.path.join(path, "members_role.npy")) + population.members_position = numpy.load(os.path.join(path, "members_position.npy")) + population.members_entity_id = numpy.load( + os.path.join(path, "members_entity_id.npy") + ) + encoded_roles = numpy.load(os.path.join(path, "members_role.npy")) flattened_roles = population.entity.flattened_roles if len(flattened_roles) == 0: - population.members_role = np.int64(0) + population.members_role = numpy.int16(0) else: - population.members_role = np.select( + population.members_role = numpy.select( [encoded_roles == role.key for role in flattened_roles], - [role for role in flattened_roles], - ) + list(flattened_roles), + ) person_count = len(population.members_entity_id) population.count = max(population.members_entity_id) + 1 return person_count -def _restore_holder(simulation, variable, directory): +def _restore_holder(simulation, variable, directory) -> None: storage_dir = os.path.join(directory, variable) - is_variable_eternal = simulation.tax_benefit_system.get_variable(variable).definition_period == ETERNITY + is_variable_eternal = ( + simulation.tax_benefit_system.get_variable(variable).definition_period + == DateUnit.ETERNITY + ) disk_storage = OnDiskStorage( storage_dir, - is_eternal = is_variable_eternal, - preserve_storage_dir = True - ) + is_eternal=is_variable_eternal, + preserve_storage_dir=True, + ) disk_storage.restore() holder = simulation.get_holder(variable) diff --git a/openfisca_core/tools/test_runner.py b/openfisca_core/tools/test_runner.py index 1e01bc0fca..fcb5572b79 100644 --- a/openfisca_core/tools/test_runner.py +++ b/openfisca_core/tools/test_runner.py @@ -1,22 +1,83 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations -import warnings -import sys +from collections.abc import Sequence +from typing import Any +from typing_extensions import Literal, TypedDict + +from openfisca_core.types import TaxBenefitSystem + +import dataclasses import os -import traceback +import pathlib +import sys import textwrap -from typing import Dict, List +import traceback +import warnings import pytest -from openfisca_core.tools import assert_near -from openfisca_core.simulation_builder import SimulationBuilder from openfisca_core.errors import SituationParsingError, VariableNotFound +from openfisca_core.simulation_builder import SimulationBuilder +from openfisca_core.tools import assert_near from openfisca_core.warnings import LibYAMLWarning +class Options(TypedDict, total=False): + aggregate: bool + ignore_variables: Sequence[str] | None + max_depth: int | None + name_filter: str | None + only_variables: Sequence[str] | None + pdb: bool + performance_graph: bool + performance_tables: bool + verbose: bool + + +@dataclasses.dataclass(frozen=True) +class ErrorMargin: + __root__: dict[str | Literal["default"], float | None] + + def __getitem__(self, key: str) -> float | None: + if key in self.__root__: + return self.__root__[key] + + return self.__root__["default"] + + +@dataclasses.dataclass +class Test: + absolute_error_margin: ErrorMargin + relative_error_margin: ErrorMargin + name: str = "" + input: dict[str, float | dict[str, float]] = dataclasses.field(default_factory=dict) + output: dict[str, float | dict[str, float]] | None = None + period: str | None = None + reforms: Sequence[str] = dataclasses.field(default_factory=list) + keywords: Sequence[str] | None = None + extensions: Sequence[str] = dataclasses.field(default_factory=list) + description: str | None = None + max_spiral_loops: int | None = None + + +def build_test(params: dict[str, Any]) -> Test: + for key in ["absolute_error_margin", "relative_error_margin"]: + value = params.get(key) + + if value is None: + value = {"default": None} + + elif isinstance(value, (float, int, str)): + value = {"default": float(value)} + + params[key] = ErrorMargin(value) + + return Test(**params) + + def import_yaml(): import yaml + try: from yaml import CLoader as Loader except ImportError: @@ -24,33 +85,55 @@ def import_yaml(): "libyaml is not installed in your environment.", "This can make your test suite slower to run. Once you have installed libyaml, ", "run 'pip uninstall pyyaml && pip install pyyaml --no-cache-dir'", - "so that it is used in your Python environment." - ] - warnings.warn(" ".join(message), LibYAMLWarning) + "so that it is used in your Python environment.", + ] + warnings.warn(" ".join(message), LibYAMLWarning, stacklevel=2) from yaml import SafeLoader as Loader return yaml, Loader -TEST_KEYWORDS = {'absolute_error_margin', 'description', 'extensions', 'ignore_variables', 'input', 'keywords', 'max_spiral_loops', 'name', 'only_variables', 'output', 'period', 'reforms', 'relative_error_margin'} +TEST_KEYWORDS = { + "absolute_error_margin", + "description", + "extensions", + "ignore_variables", + "input", + "keywords", + "max_spiral_loops", + "name", + "only_variables", + "output", + "period", + "reforms", + "relative_error_margin", +} yaml, Loader = import_yaml() -_tax_benefit_system_cache: Dict = {} +_tax_benefit_system_cache: dict = {} +options: Options = Options() -def run_tests(tax_benefit_system, paths, options = None): - """ - Runs all the YAML tests contained in a file or a directory. - If `path` is a directory, subdirectories will be recursively explored. +def run_tests( + tax_benefit_system: TaxBenefitSystem, + paths: str | Sequence[str], + options: Options = options, +) -> int: + """Runs all the YAML tests contained in a file or a directory. + + If ``path`` is a directory, subdirectories will be recursively explored. - :param .TaxBenefitSystem tax_benefit_system: the tax-benefit system to use to run the tests - :param str or list paths: A path, or a list of paths, towards the files or directories containing the tests to run. If a path is a directory, subdirectories will be recursively explored. - :param dict options: See more details below. + Args: + tax_benefit_system: the tax-benefit system to use to run the tests. + paths: A path, or a list of paths, towards the files or directories containing the tests to run. If a path is a directory, subdirectories will be recursively explored. + options: See more details below. - :raises :exc:`AssertionError`: if a test does not pass + Returns: + The number of successful tests executed. - :return: the number of sucessful tests excecuted + Raises: + :exc:`AssertionError`: if a test does not pass. **Testing options**: @@ -63,89 +146,95 @@ def run_tests(tax_benefit_system, paths, options = None): +-------------------------------+-----------+-------------------------------------------+ """ + argv = [] + plugins = [OpenFiscaPlugin(tax_benefit_system, options)] - argv = ["--capture", "no"] + if options.get("pdb"): + argv.append("--pdb") - if options.get('pdb'): - argv.append('--pdb') + if options.get("verbose"): + argv.append("--verbose") if isinstance(paths, str): paths = [paths] - if options is None: - options = {} - - return pytest.main([*argv, *paths] if True else paths, plugins = [OpenFiscaPlugin(tax_benefit_system, options)]) + return pytest.main([*argv, *paths], plugins=plugins) class YamlFile(pytest.File): - - def __init__(self, path, fspath, parent, tax_benefit_system, options): - super(YamlFile, self).__init__(path, parent) + def __init__(self, *, tax_benefit_system, options, **kwargs) -> None: + super().__init__(**kwargs) self.tax_benefit_system = tax_benefit_system self.options = options def collect(self): try: - tests = yaml.load(self.fspath.open(), Loader = Loader) + tests = yaml.load(open(self.path), Loader=Loader) except (yaml.scanner.ScannerError, yaml.parser.ParserError, TypeError): - message = os.linesep.join([ - traceback.format_exc(), - f"'{self.fspath}' is not a valid YAML file. Check the stack trace above for more details.", - ]) + message = os.linesep.join( + [ + traceback.format_exc(), + f"'{self.path}' is not a valid YAML file. Check the stack trace above for more details.", + ], + ) raise ValueError(message) if not isinstance(tests, list): - tests: List[Dict] = [tests] + tests: Sequence[dict] = [tests] for test in tests: if not self.should_ignore(test): - yield YamlItem.from_parent(self, - name = '', - baseline_tax_benefit_system = self.tax_benefit_system, - test = test, options = self.options) + yield YamlItem.from_parent( + self, + name="", + baseline_tax_benefit_system=self.tax_benefit_system, + test=test, + options=self.options, + ) def should_ignore(self, test): - name_filter = self.options.get('name_filter') + name_filter = self.options.get("name_filter") return ( name_filter is not None - and name_filter not in os.path.splitext(self.fspath.basename)[0] - and name_filter not in test.get('name', '') - and name_filter not in test.get('keywords', []) - ) + and name_filter not in os.path.splitext(os.path.basename(self.path))[0] + and name_filter not in test.get("name", "") + and name_filter not in test.get("keywords", []) + ) class YamlItem(pytest.Item): - """ - Terminal nodes of the test collection tree. - """ + """Terminal nodes of the test collection tree.""" - def __init__(self, name, parent, baseline_tax_benefit_system, test, options): - super(YamlItem, self).__init__(name, parent) + def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs) -> None: + super().__init__(**kwargs) self.baseline_tax_benefit_system = baseline_tax_benefit_system self.options = options - self.test = test + self.test = build_test(test) self.simulation = None self.tax_benefit_system = None - def runtest(self): - self.name = self.test.get('name', '') - if not self.test.get('output'): - raise ValueError("Missing key 'output' in test '{}' in file '{}'".format(self.name, self.fspath)) + def runtest(self) -> None: + self.name = self.test.name - if not TEST_KEYWORDS.issuperset(self.test.keys()): - unexpected_keys = set(self.test.keys()).difference(TEST_KEYWORDS) - raise ValueError("Unexpected keys {} in test '{}' in file '{}'".format(unexpected_keys, self.name, self.fspath)) + if self.test.output is None: + msg = f"Missing key 'output' in test '{self.name}' in file '{self.path}'" + raise ValueError(msg) - self.tax_benefit_system = _get_tax_benefit_system(self.baseline_tax_benefit_system, self.test.get('reforms', []), self.test.get('extensions', [])) + self.tax_benefit_system = _get_tax_benefit_system( + self.baseline_tax_benefit_system, + self.test.reforms, + self.test.extensions, + ) builder = SimulationBuilder() - input = self.test.get('input', {}) - period = self.test.get('period') - max_spiral_loops = self.test.get('max_spiral_loops') - verbose = self.options.get('verbose') - performance_graph = self.options.get('performance_graph') - performance_tables = self.options.get('performance_tables') + input = self.test.input + period = self.test.period + max_spiral_loops = self.test.max_spiral_loops + verbose = self.options.get("verbose") + aggregate = self.options.get("aggregate") + max_depth = self.options.get("max_depth") + performance_graph = self.options.get("performance_graph") + performance_tables = self.options.get("performance_tables") try: builder.set_default_period(period) @@ -153,8 +242,12 @@ def runtest(self): except (VariableNotFound, SituationParsingError): raise except Exception as e: - error_message = os.linesep.join([str(e), '', f"Unexpected error raised while parsing '{self.fspath}'"]) - raise ValueError(error_message).with_traceback(sys.exc_info()[2]) from e # Keep the stack trace from the root error + error_message = os.linesep.join( + [str(e), "", f"Unexpected error raised while parsing '{self.path}'"], + ) + raise ValueError(error_message).with_traceback( + sys.exc_info()[2], + ) from e # Keep the stack trace from the root error if max_spiral_loops: self.simulation.max_spiral_loops = max_spiral_loops @@ -165,101 +258,130 @@ def runtest(self): finally: tracer = self.simulation.tracer if verbose: - self.print_computation_log(tracer) + self.print_computation_log(tracer, aggregate, max_depth) if performance_graph: self.generate_performance_graph(tracer) if performance_tables: self.generate_performance_tables(tracer) - def print_computation_log(self, tracer): - print("Computation log:") # noqa T001 - tracer.print_computation_log() + def print_computation_log(self, tracer, aggregate, max_depth) -> None: + tracer.print_computation_log(aggregate, max_depth) - def generate_performance_graph(self, tracer): - tracer.generate_performance_graph('.') + def generate_performance_graph(self, tracer) -> None: + tracer.generate_performance_graph(".") - def generate_performance_tables(self, tracer): - tracer.generate_performance_tables('.') + def generate_performance_tables(self, tracer) -> None: + tracer.generate_performance_tables(".") - def check_output(self): - output = self.test.get('output') + def check_output(self) -> None: + output = self.test.output if output is None: return for key, expected_value in output.items(): if self.tax_benefit_system.get_variable(key): # If key is a variable - self.check_variable(key, expected_value, self.test.get('period')) + self.check_variable(key, expected_value, self.test.period) elif self.simulation.populations.get(key): # If key is an entity singular for variable_name, value in expected_value.items(): - self.check_variable(variable_name, value, self.test.get('period')) + self.check_variable(variable_name, value, self.test.period) else: - population = self.simulation.get_population(plural = key) + population = self.simulation.get_population(plural=key) if population is not None: # If key is an entity plural for instance_id, instance_values in expected_value.items(): for variable_name, value in instance_values.items(): entity_index = population.get_index(instance_id) - self.check_variable(variable_name, value, self.test.get('period'), entity_index) + self.check_variable( + variable_name, + value, + self.test.period, + entity_index, + ) else: raise VariableNotFound(key, self.tax_benefit_system) - def check_variable(self, variable_name, expected_value, period, entity_index = None): + def check_variable( + self, + variable_name: str, + expected_value, + period, + entity_index=None, + ): if self.should_ignore_variable(variable_name): - return + return None + if isinstance(expected_value, dict): for requested_period, expected_value_at_period in expected_value.items(): - self.check_variable(variable_name, expected_value_at_period, requested_period, entity_index) - return + self.check_variable( + variable_name, + expected_value_at_period, + requested_period, + entity_index, + ) + + return None actual_value = self.simulation.calculate(variable_name, period) if entity_index is not None: actual_value = actual_value[entity_index] + return assert_near( actual_value, expected_value, - absolute_error_margin = self.test.get('absolute_error_margin'), - message = f"{variable_name}@{period}: ", - relative_error_margin = self.test.get('relative_error_margin'), - ) - - def should_ignore_variable(self, variable_name): - only_variables = self.options.get('only_variables') - ignore_variables = self.options.get('ignore_variables') - variable_ignored = ignore_variables is not None and variable_name in ignore_variables - variable_not_tested = only_variables is not None and variable_name not in only_variables + self.test.absolute_error_margin[variable_name], + f"{variable_name}@{period}: ", + self.test.relative_error_margin[variable_name], + ) + + def should_ignore_variable(self, variable_name: str): + only_variables = self.options.get("only_variables") + ignore_variables = self.options.get("ignore_variables") + variable_ignored = ( + ignore_variables is not None and variable_name in ignore_variables + ) + variable_not_tested = ( + only_variables is not None and variable_name not in only_variables + ) return variable_ignored or variable_not_tested def repr_failure(self, excinfo): - if not isinstance(excinfo.value, (AssertionError, VariableNotFound, SituationParsingError)): - return super(YamlItem, self).repr_failure(excinfo) + if not isinstance( + excinfo.value, + (AssertionError, VariableNotFound, SituationParsingError), + ): + return super().repr_failure(excinfo) message = excinfo.value.args[0] if isinstance(excinfo.value, SituationParsingError): message = f"Could not parse situation described: {message}" - return os.linesep.join([ - f"{str(self.fspath)}:", - f" Test '{str(self.name)}':", - textwrap.indent(message, ' ') - ]) - + return os.linesep.join( + [ + f"{self.path!s}:", + f" Test '{self.name!s}':", + textwrap.indent(message, " "), + ], + ) -class OpenFiscaPlugin(object): - def __init__(self, tax_benefit_system, options): +class OpenFiscaPlugin: + def __init__(self, tax_benefit_system, options) -> None: self.tax_benefit_system = tax_benefit_system self.options = options def pytest_collect_file(self, parent, path): - """ - Called by pytest for all plugins. + """Called by pytest for all plugins. :return: The collector for test methods. """ if path.ext in [".yaml", ".yml"]: - return YamlFile.from_parent(parent, path = path, fspath = path, - tax_benefit_system = self.tax_benefit_system, - options = self.options) + return YamlFile.from_parent( + parent, + path=pathlib.Path(path), + tax_benefit_system=self.tax_benefit_system, + options=self.options, + ) + return None def _get_tax_benefit_system(baseline, reforms, extensions): @@ -269,17 +391,18 @@ def _get_tax_benefit_system(baseline, reforms, extensions): extensions = [extensions] # keep reforms order in cache, ignore extensions order - key = hash((id(baseline), ':'.join(reforms), frozenset(extensions))) + key = hash((id(baseline), ":".join(reforms), frozenset(extensions))) if _tax_benefit_system_cache.get(key): return _tax_benefit_system_cache.get(key) - current_tax_benefit_system = baseline + current_tax_benefit_system = baseline.clone() for reform_path in reforms: - current_tax_benefit_system = current_tax_benefit_system.apply_reform(reform_path) + current_tax_benefit_system = current_tax_benefit_system.apply_reform( + reform_path, + ) for extension in extensions: - current_tax_benefit_system = current_tax_benefit_system.clone() current_tax_benefit_system.load_extension(extension) _tax_benefit_system_cache[key] = current_tax_benefit_system diff --git a/openfisca_core/tracers/__init__.py b/openfisca_core/tracers/__init__.py index de489ad6d9..e59d0122a2 100644 --- a/openfisca_core/tracers/__init__.py +++ b/openfisca_core/tracers/__init__.py @@ -27,4 +27,6 @@ from .performance_log import PerformanceLog # noqa: F401 from .simple_tracer import SimpleTracer # noqa: F401 from .trace_node import TraceNode # noqa: F401 -from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant # noqa: F401 +from .tracing_parameter_node_at_instant import ( # noqa: F401 + TracingParameterNodeAtInstant, +) diff --git a/openfisca_core/tracers/computation_log.py b/openfisca_core/tracers/computation_log.py index c785fd9395..6310eb8849 100644 --- a/openfisca_core/tracers/computation_log.py +++ b/openfisca_core/tracers/computation_log.py @@ -1,103 +1,120 @@ from __future__ import annotations import typing -from typing import List, Optional, Union +from typing import Union import numpy -from .. import tracers from openfisca_core.indexed_enums import EnumArray if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Array = Union[EnumArray, ArrayLike] class ComputationLog: - _full_tracer: tracers.FullTracer def __init__(self, full_tracer: tracers.FullTracer) -> None: self._full_tracer = full_tracer def display( - self, - value: Optional[Array], - ) -> str: + self, + value: Array | None, + ) -> str: if isinstance(value, EnumArray): value = value.decode_to_str() - return numpy.array2string(value, max_line_width = float("inf")) - - def _get_node_log( - self, - node: tracers.TraceNode, - depth: int, - aggregate: bool, - ) -> List[str]: + return numpy.array2string(value, max_line_width=float("inf")) - def print_line(depth: int, node: tracers.TraceNode) -> str: - indent = ' ' * depth - value = node.value + def lines( + self, + aggregate: bool = False, + max_depth: int | None = None, + ) -> list[str]: + depth = 1 - if value is None: - formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + lines_by_tree = [ + self._get_node_log(node, depth, aggregate, max_depth) + for node in self._full_tracer.trees + ] - elif aggregate: - try: - formatted_value = str({ - 'avg': numpy.mean(value), - 'max': numpy.max(value), - 'min': numpy.min(value), - }) + return self._flatten(lines_by_tree) - except TypeError: - formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + def print_log(self, aggregate=False, max_depth=None) -> None: + """Print the computation log of a simulation. - else: - formatted_value = self.display(value) + If ``aggregate`` is ``False`` (default), print the value of each + computed vector. - return f"{indent}{node.name}<{node.period}> >> {formatted_value}" + If ``aggregate`` is ``True``, only print the minimum, maximum, and + average value of each computed vector. - node_log = [print_line(depth, node)] + This mode is more suited for simulations on a large population. - children_logs = [ - self._get_node_log(child, depth + 1, aggregate) - for child - in node.children - ] + If ``max_depth`` is ``None`` (default), print the entire computation. - return node_log + self._flatten(children_logs) + If ``max_depth`` is set, for example to ``3``, only print computed + vectors up to a depth of ``max_depth``. + """ + for _line in self.lines(aggregate, max_depth): + pass - def _flatten( - self, - list_of_lists: List[List[str]], - ) -> List[str]: - return [item for _list in list_of_lists for item in _list] + def _get_node_log( + self, + node: tracers.TraceNode, + depth: int, + aggregate: bool, + max_depth: int | None, + ) -> list[str]: + if max_depth is not None and depth > max_depth: + return [] - def lines(self, aggregate: bool = False) -> List[str]: - depth = 1 + node_log = [self._print_line(depth, node, aggregate, max_depth)] - lines_by_tree = [ - self._get_node_log(node, depth, aggregate) - for node - in self._full_tracer.trees - ] + children_logs = [ + self._get_node_log(child, depth + 1, aggregate, max_depth) + for child in node.children + ] - return self._flatten(lines_by_tree) + return node_log + self._flatten(children_logs) - def print_log(self, aggregate = False) -> None: - """ - Print the computation log of a simulation. + def _print_line( + self, + depth: int, + node: tracers.TraceNode, + aggregate: bool, + max_depth: int | None, + ) -> str: + indent = " " * depth + value = node.value + + if value is None: + formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + + elif aggregate: + try: + formatted_value = str( + { + "avg": numpy.mean(value), + "max": numpy.max(value), + "min": numpy.min(value), + }, + ) + + except TypeError: + formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" - If ``aggregate`` is ``False`` (default), print the value of each - computed vector. + else: + formatted_value = self.display(value) - If ``aggregate`` is ``True``, only print the minimum, maximum, and - average value of each computed vector. + return f"{indent}{node.name}<{node.period}> >> {formatted_value}" - This mode is more suited for simulations on a large population. - """ - for line in self.lines(aggregate): - print(line) # noqa T001 + def _flatten( + self, + lists: list[list[str]], + ) -> list[str]: + return [item for list_ in lists for item in list_] diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py index d51dd2576b..2090d537b8 100644 --- a/openfisca_core/tracers/flat_trace.py +++ b/openfisca_core/tracers/flat_trace.py @@ -1,22 +1,22 @@ from __future__ import annotations import typing -from typing import Dict, Optional, Union +from typing import Union import numpy -from openfisca_core import tracers from openfisca_core.indexed_enums import EnumArray if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Array = Union[EnumArray, ArrayLike] - Trace = Dict[str, dict] + Trace = dict[str, dict] class FlatTrace: - _full_tracer: tracers.FullTracer def __init__(self, full_tracer: tracers.FullTracer) -> None: @@ -35,32 +35,33 @@ def get_trace(self) -> dict: # calculation. # # We therefore use a non-overwriting update. - trace.update({ - key: node_trace - for key, node_trace in self._get_flat_trace(node).items() - if key not in trace - }) + trace.update( + { + key: node_trace + for key, node_trace in self._get_flat_trace(node).items() + if key not in trace + }, + ) return trace def get_serialized_trace(self) -> dict: return { - key: { - **flat_trace, - 'value': self.serialize(flat_trace['value']) - } + key: {**flat_trace, "value": self.serialize(flat_trace["value"])} for key, flat_trace in self.get_trace().items() - } + } def serialize( - self, - value: Optional[Array], - ) -> Union[Optional[Array], list]: + self, + value: Array | None, + ) -> Array | None | list: if isinstance(value, EnumArray): value = value.decode_to_str() - if isinstance(value, numpy.ndarray) and \ - numpy.issubdtype(value.dtype, numpy.dtype(bytes)): + if isinstance(value, numpy.ndarray) and numpy.issubdtype( + value.dtype, + numpy.dtype(bytes), + ): value = value.astype(numpy.dtype(str)) if isinstance(value, numpy.ndarray): @@ -69,26 +70,20 @@ def serialize( return value def _get_flat_trace( - self, - node: tracers.TraceNode, - ) -> Trace: + self, + node: tracers.TraceNode, + ) -> Trace: key = self.key(node) - node_trace = { + return { key: { - 'dependencies': [ - self.key(child) for child in node.children - ], - 'parameters': { - self.key(parameter): - self.serialize(parameter.value) - for parameter - in node.parameters - }, - 'value': node.value, - 'calculation_time': node.calculation_time(), - 'formula_time': node.formula_time(), + "dependencies": [self.key(child) for child in node.children], + "parameters": { + self.key(parameter): self.serialize(parameter.value) + for parameter in node.parameters }, - } - - return node_trace + "value": node.value, + "calculation_time": node.calculation_time(), + "formula_time": node.formula_time(), + }, + } diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index 3fa46de5ab..9fa94d5ab5 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -1,24 +1,25 @@ from __future__ import annotations -import time import typing -from typing import Dict, Iterator, List, Optional, Union +from typing import Union + +import time -from .. import tracers +from openfisca_core import tracers if typing.TYPE_CHECKING: + from collections.abc import Iterator from numpy.typing import ArrayLike from openfisca_core.periods import Period - Stack = List[Dict[str, Union[str, Period]]] + Stack = list[dict[str, Union[str, Period]]] class FullTracer: - _simple_tracer: tracers.SimpleTracer _trees: list - _current_node: Optional[tracers.TraceNode] + _current_node: tracers.TraceNode | None def __init__(self) -> None: self._simple_tracer = tracers.SimpleTracer() @@ -26,24 +27,24 @@ def __init__(self) -> None: self._current_node = None def record_calculation_start( - self, - variable: str, - period: Period, - ) -> None: + self, + variable: str, + period: Period | int, + ) -> None: self._simple_tracer.record_calculation_start(variable, period) self._enter_calculation(variable, period) self._record_start_time() def _enter_calculation( - self, - variable: str, - period: Period, - ) -> None: + self, + variable: str, + period: Period, + ) -> None: new_node = tracers.TraceNode( - name = variable, - period = period, - parent = self._current_node, - ) + name=variable, + period=period, + parent=self._current_node, + ) if self._current_node is None: self._trees.append(new_node) @@ -54,21 +55,20 @@ def _enter_calculation( self._current_node = new_node def record_parameter_access( - self, - parameter: str, - period: Period, - value: ArrayLike, - ) -> None: - + self, + parameter: str, + period: Period, + value: ArrayLike, + ) -> None: if self._current_node is not None: self._current_node.parameters.append( - tracers.TraceNode(name = parameter, period = period, value = value), - ) + tracers.TraceNode(name=parameter, period=period, value=value), + ) def _record_start_time( - self, - time_in_s: Optional[float] = None, - ) -> None: + self, + time_in_s: float | None = None, + ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() @@ -85,9 +85,9 @@ def record_calculation_end(self) -> None: self._exit_calculation() def _record_end_time( - self, - time_in_s: Optional[float] = None, - ) -> None: + self, + time_in_s: float | None = None, + ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() @@ -103,7 +103,7 @@ def stack(self) -> Stack: return self._simple_tracer.stack @property - def trees(self) -> List[tracers.TraceNode]: + def trees(self) -> list[tracers.TraceNode]: return self._trees @property @@ -121,8 +121,8 @@ def flat_trace(self) -> tracers.FlatTrace: def _get_time_in_sec(self) -> float: return time.time_ns() / (10**9) - def print_computation_log(self, aggregate = False): - self.computation_log.print_log(aggregate) + def print_computation_log(self, aggregate=False, max_depth=None) -> None: + self.computation_log.print_log(aggregate, max_depth) def generate_performance_graph(self, dir_path: str) -> None: self.performance_log.generate_graph(dir_path) @@ -133,19 +133,13 @@ def generate_performance_tables(self, dir_path: str) -> None: def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: tree_call = tree.name == variable children_calls = sum( - self._get_nb_requests(child, variable) - for child - in tree.children - ) + self._get_nb_requests(child, variable) for child in tree.children + ) return tree_call + children_calls def get_nb_requests(self, variable: str) -> int: - return sum( - self._get_nb_requests(tree, variable) - for tree - in self.trees - ) + return sum(self._get_nb_requests(tree, variable) for tree in self.trees) def get_flat_trace(self) -> dict: return self.flat_trace.get_trace() @@ -154,7 +148,6 @@ def get_serialized_flat_trace(self) -> dict: return self.flat_trace.get_serialized_trace() def browse_trace(self) -> Iterator[tracers.TraceNode]: - def _browse_node(node): yield node diff --git a/openfisca_core/tracers/performance_log.py b/openfisca_core/tracers/performance_log.py index 754d7f8056..f69a3dd3a2 100644 --- a/openfisca_core/tracers/performance_log.py +++ b/openfisca_core/tracers/performance_log.py @@ -1,36 +1,36 @@ from __future__ import annotations +import typing + import csv import importlib.resources import itertools import json import os -import typing -from .. import tracers +from openfisca_core import tracers if typing.TYPE_CHECKING: - Trace = typing.Dict[str, dict] - Calculation = typing.Tuple[str, dict] - SortedTrace = typing.List[Calculation] + Trace = dict[str, dict] + Calculation = tuple[str, dict] + SortedTrace = list[Calculation] class PerformanceLog: - def __init__(self, full_tracer: tracers.FullTracer) -> None: self._full_tracer = full_tracer def generate_graph(self, dir_path: str) -> None: - with open(os.path.join(dir_path, 'performance_graph.html'), 'w') as f: + with open(os.path.join(dir_path, "performance_graph.html"), "w") as f: template = importlib.resources.read_text( - 'openfisca_core.scripts.assets', - 'index.html', - ) + "openfisca_core.scripts.assets", + "index.html", + ) perf_graph_html = template.replace( - '{{data}}', + "{{data}}", json.dumps(self._json()), - ) + ) f.write(perf_graph_html) @@ -39,94 +39,95 @@ def generate_performance_tables(self, dir_path: str) -> None: csv_rows = [ { - 'name': key, - 'calculation_time': trace['calculation_time'], - 'formula_time': trace['formula_time'], - } - for key, trace - in flat_trace.items() - ] + "name": key, + "calculation_time": trace["calculation_time"], + "formula_time": trace["formula_time"], + } + for key, trace in flat_trace.items() + ] self._write_csv( - os.path.join(dir_path, 'performance_table.csv'), + os.path.join(dir_path, "performance_table.csv"), csv_rows, - ) + ) aggregated_csv_rows = [ - {'name': key, **aggregated_time} - for key, aggregated_time - in self.aggregate_calculation_times(flat_trace).items() - ] + {"name": key, **aggregated_time} + for key, aggregated_time in self.aggregate_calculation_times( + flat_trace, + ).items() + ] self._write_csv( - os.path.join(dir_path, 'aggregated_performance_table.csv'), + os.path.join(dir_path, "aggregated_performance_table.csv"), aggregated_csv_rows, - ) + ) def aggregate_calculation_times( - self, - flat_trace: Trace, - ) -> typing.Dict[str, dict]: - + self, + flat_trace: Trace, + ) -> dict[str, dict]: def _aggregate_calculations(calculations: list) -> dict: calculation_count = len(calculations) calculation_time = sum( - calculation[1]['calculation_time'] - for calculation - in calculations - ) + calculation[1]["calculation_time"] for calculation in calculations + ) formula_time = sum( - calculation[1]['formula_time'] - for calculation - in calculations - ) + calculation[1]["formula_time"] for calculation in calculations + ) return { - 'calculation_count': calculation_count, - 'calculation_time': tracers.TraceNode.round(calculation_time), - 'formula_time': tracers.TraceNode.round(formula_time), - 'avg_calculation_time': tracers.TraceNode.round(calculation_time / calculation_count), - 'avg_formula_time': tracers.TraceNode.round(formula_time / calculation_count), - } + "calculation_count": calculation_count, + "calculation_time": tracers.TraceNode.round(calculation_time), + "formula_time": tracers.TraceNode.round(formula_time), + "avg_calculation_time": tracers.TraceNode.round( + calculation_time / calculation_count, + ), + "avg_formula_time": tracers.TraceNode.round( + formula_time / calculation_count, + ), + } def _groupby(calculation: Calculation) -> str: - return calculation[0].split('<')[0] + return calculation[0].split("<")[0] all_calculations: SortedTrace = sorted(flat_trace.items()) return { variable_name: _aggregate_calculations(list(calculations)) - for variable_name, calculations - in itertools.groupby(all_calculations, _groupby) - } + for variable_name, calculations in itertools.groupby( + all_calculations, + _groupby, + ) + } def _json(self) -> dict: children = [self._json_tree(tree) for tree in self._full_tracer.trees] - calculations_total_time = sum(child['value'] for child in children) + calculations_total_time = sum(child["value"] for child in children) return { - 'name': 'All calculations', - 'value': calculations_total_time, - 'children': children, - } + "name": "All calculations", + "value": calculations_total_time, + "children": children, + } def _json_tree(self, tree: tracers.TraceNode) -> dict: calculation_total_time = tree.calculation_time() children = [self._json_tree(child) for child in tree.children] return { - 'name': f"{tree.name}<{tree.period}>", - 'value': calculation_total_time, - 'children': children, - } + "name": f"{tree.name}<{tree.period}>", + "value": calculation_total_time, + "children": children, + } - def _write_csv(self, path: str, rows: typing.List[dict]) -> None: + def _write_csv(self, path: str, rows: list[dict]) -> None: fieldnames = list(rows[0].keys()) - with open(path, 'w') as csv_file: - writer = csv.DictWriter(csv_file, fieldnames = fieldnames) + with open(path, "w") as csv_file: + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) writer.writeheader() for row in rows: diff --git a/openfisca_core/tracers/simple_tracer.py b/openfisca_core/tracers/simple_tracer.py index 2fa98c6582..84328730ef 100644 --- a/openfisca_core/tracers/simple_tracer.py +++ b/openfisca_core/tracers/simple_tracer.py @@ -1,30 +1,29 @@ from __future__ import annotations import typing -from typing import Dict, List, Union +from typing import Union if typing.TYPE_CHECKING: from numpy.typing import ArrayLike from openfisca_core.periods import Period - Stack = List[Dict[str, Union[str, Period]]] + Stack = list[dict[str, Union[str, Period]]] class SimpleTracer: - _stack: Stack def __init__(self) -> None: self._stack = [] - def record_calculation_start(self, variable: str, period: Period) -> None: - self.stack.append({'name': variable, 'period': period}) + def record_calculation_start(self, variable: str, period: Period | int) -> None: + self.stack.append({"name": variable, "period": period}) def record_calculation_result(self, value: ArrayLike) -> None: pass # ignore calculation result - def record_parameter_access(self, parameter: str, period, value): + def record_parameter_access(self, parameter: str, period, value) -> None: pass def record_calculation_end(self) -> None: diff --git a/openfisca_core/tracers/trace_node.py b/openfisca_core/tracers/trace_node.py index 93b630886c..ff55a5714f 100644 --- a/openfisca_core/tracers/trace_node.py +++ b/openfisca_core/tracers/trace_node.py @@ -1,8 +1,9 @@ from __future__ import annotations -import dataclasses import typing +import dataclasses + if typing.TYPE_CHECKING: import numpy @@ -17,10 +18,10 @@ class TraceNode: name: str period: Period - parent: typing.Optional[TraceNode] = None - children: typing.List[TraceNode] = dataclasses.field(default_factory = list) - parameters: typing.List[TraceNode] = dataclasses.field(default_factory = list) - value: typing.Optional[Array] = None + parent: TraceNode | None = None + children: list[TraceNode] = dataclasses.field(default_factory=list) + parameters: list[TraceNode] = dataclasses.field(default_factory=list) + value: Array | None = None start: float = 0 end: float = 0 @@ -34,15 +35,10 @@ def calculation_time(self, round_: bool = True) -> Time: def formula_time(self) -> float: children_calculation_time = sum( - child.calculation_time(round_ = False) - for child - in self.children - ) + child.calculation_time(round_=False) for child in self.children + ) - result = ( - + self.calculation_time(round_ = False) - - children_calculation_time - ) + result = +self.calculation_time(round_=False) - children_calculation_time return self.round(result) @@ -51,4 +47,4 @@ def append_child(self, node: TraceNode) -> None: @staticmethod def round(time: Time) -> float: - return float(f'{time:.4g}') # Keep only 4 significant figures + return float(f"{time:.4g}") # Keep only 4 significant figures diff --git a/openfisca_core/tracers/tracing_parameter_node_at_instant.py b/openfisca_core/tracers/tracing_parameter_node_at_instant.py index 89d9b8fb01..074c24221d 100644 --- a/openfisca_core/tracers/tracing_parameter_node_at_instant.py +++ b/openfisca_core/tracers/tracing_parameter_node_at_instant.py @@ -7,70 +7,77 @@ from openfisca_core import parameters -from .. import tracers - ParameterNode = Union[ parameters.VectorialParameterNodeAtInstant, parameters.ParameterNodeAtInstant, - ] +] if typing.TYPE_CHECKING: from numpy.typing import ArrayLike + from openfisca_core import tracers + Child = Union[ParameterNode, ArrayLike] class TracingParameterNodeAtInstant: - def __init__( - self, - parameter_node_at_instant: ParameterNode, - tracer: tracers.FullTracer, - ) -> None: + self, + parameter_node_at_instant: ParameterNode, + tracer: tracers.FullTracer, + ) -> None: self.parameter_node_at_instant = parameter_node_at_instant self.tracer = tracer def __getattr__( - self, - key: str, - ) -> Union[TracingParameterNodeAtInstant, Child]: + self, + key: str, + ) -> TracingParameterNodeAtInstant | Child: child = getattr(self.parameter_node_at_instant, key) return self.get_traced_child(child, key) + def __contains__(self, key) -> bool: + return key in self.parameter_node_at_instant + + def __iter__(self): + return iter(self.parameter_node_at_instant) + def __getitem__( - self, - key: Union[str, ArrayLike], - ) -> Union[TracingParameterNodeAtInstant, Child]: + self, + key: str | ArrayLike, + ) -> TracingParameterNodeAtInstant | Child: child = self.parameter_node_at_instant[key] return self.get_traced_child(child, key) def get_traced_child( - self, - child: Child, - key: Union[str, ArrayLike], - ) -> Union[TracingParameterNodeAtInstant, Child]: + self, + child: Child, + key: str | ArrayLike, + ) -> TracingParameterNodeAtInstant | Child: period = self.parameter_node_at_instant._instant_str if isinstance( - child, - (parameters.ParameterNodeAtInstant, parameters.VectorialParameterNodeAtInstant), - ): + child, + ( + parameters.ParameterNodeAtInstant, + parameters.VectorialParameterNodeAtInstant, + ), + ): return TracingParameterNodeAtInstant(child, self.tracer) - if not isinstance(key, str) or \ - isinstance( - self.parameter_node_at_instant, - parameters.VectorialParameterNodeAtInstant, - ): + if not isinstance(key, str) or isinstance( + self.parameter_node_at_instant, + parameters.VectorialParameterNodeAtInstant, + ): # In case of vectorization, we keep the parent node name as, for # instance, rate[status].zone1 is best described as the value of # "rate". name = self.parameter_node_at_instant._name else: - name = '.'.join([self.parameter_node_at_instant._name, key]) + name = f"{self.parameter_node_at_instant._name}.{key}" - if isinstance(child, (numpy.ndarray,) + parameters.ALLOWED_PARAM_TYPES): + if isinstance(child, (numpy.ndarray, *parameters.ALLOWED_PARAM_TYPES)): self.tracer.record_parameter_access(name, period, child) return child diff --git a/openfisca_core/types.py b/openfisca_core/types.py new file mode 100644 index 0000000000..711e6c512f --- /dev/null +++ b/openfisca_core/types.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence, Sized +from numpy.typing import NDArray +from typing import Any, NewType, TypeVar, Union +from typing_extensions import Protocol, TypeAlias + +import numpy +import pendulum + +_N_co = TypeVar("_N_co", bound=numpy.generic, covariant=True) + +#: Type representing an numpy array. +Array: TypeAlias = NDArray[_N_co] + +_L = TypeVar("_L") + +#: Type representing an array-like object. +ArrayLike: TypeAlias = Sequence[_L] + +#: Generic type vars. +_T_co = TypeVar("_T_co", covariant=True) + + +# Entities + +#: For example "person". +EntityKey = NewType("EntityKey", str) + +#: For example "persons". +EntityPlural = NewType("EntityPlural", str) + +#: For example "principal". +RoleKey = NewType("RoleKey", str) + +#: For example "parents". +RolePlural = NewType("RolePlural", str) + + +class CoreEntity(Protocol): + key: EntityKey + plural: EntityPlural + + def check_role_validity(self, role: object, /) -> None: ... + def check_variable_defined_for_entity( + self, + variable_name: VariableName, + /, + ) -> None: ... + def get_variable( + self, + variable_name: VariableName, + check_existence: bool = ..., + /, + ) -> None | Variable: ... + + +class SingleEntity(CoreEntity, Protocol): ... + + +class GroupEntity(CoreEntity, Protocol): ... + + +class Role(Protocol): + entity: GroupEntity + max: int | None + subroles: None | Iterable[Role] + + @property + def key(self, /) -> RoleKey: ... + @property + def plural(self, /) -> None | RolePlural: ... + + +# Holders + + +class Holder(Protocol): + def clone(self, population: Any, /) -> Holder: ... + def get_memory_usage(self, /) -> Any: ... + + +# Parameters + + +class ParameterNodeAtInstant(Protocol): ... + + +# Periods + +#: For example "2000-01". +InstantStr = NewType("InstantStr", str) + +#: For example "1:2000-01-01:day". +PeriodStr = NewType("PeriodStr", str) + + +class Container(Protocol[_T_co]): + def __contains__(self, item: object, /) -> bool: ... + + +class Indexable(Protocol[_T_co]): + def __getitem__(self, index: int, /) -> _T_co: ... + + +class DateUnit(Container[str], Protocol): + def upper(self, /) -> str: ... + + +class Instant(Indexable[int], Iterable[int], Sized, Protocol): + @property + def year(self, /) -> int: ... + @property + def month(self, /) -> int: ... + @property + def day(self, /) -> int: ... + @property + def date(self, /) -> pendulum.Date: ... + def __lt__(self, other: object, /) -> bool: ... + def __le__(self, other: object, /) -> bool: ... + def offset(self, offset: str | int, unit: DateUnit, /) -> None | Instant: ... + + +class Period(Indexable[Union[DateUnit, Instant, int]], Protocol): + @property + def unit(self, /) -> DateUnit: ... + @property + def start(self, /) -> Instant: ... + @property + def size(self, /) -> int: ... + @property + def stop(self, /) -> Instant: ... + def offset(self, offset: str | int, unit: None | DateUnit = None, /) -> Period: ... + + +# Populations + + +class Population(Protocol): + entity: Any + + def get_holder(self, variable_name: VariableName, /) -> Any: ... + + +# Simulations + + +class Simulation(Protocol): + def calculate(self, variable_name: VariableName, period: Any, /) -> Any: ... + def calculate_add(self, variable_name: VariableName, period: Any, /) -> Any: ... + def calculate_divide(self, variable_name: VariableName, period: Any, /) -> Any: ... + def get_population(self, plural: None | str, /) -> Any: ... + + +# Tax-Benefit systems + + +class TaxBenefitSystem(Protocol): + person_entity: Any + + def get_variable( + self, + variable_name: VariableName, + check_existence: bool = ..., + /, + ) -> None | Variable: ... + + +# Variables + +#: For example "salary". +VariableName = NewType("VariableName", str) + + +class Variable(Protocol): + entity: Any + name: VariableName + + +class Formula(Protocol): + def __call__( + self, + population: Population, + instant: Instant, + params: Params, + /, + ) -> Array[Any]: ... + + +class Params(Protocol): + def __call__(self, instant: Instant, /) -> ParameterNodeAtInstant: ... diff --git a/openfisca_core/types/__init__.py b/openfisca_core/types/__init__.py deleted file mode 100644 index e6f72ecb3d..0000000000 --- a/openfisca_core/types/__init__.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Data types and protocols used by OpenFisca Core. - -The type definitions included in this sub-package are intented mostly for -contributors, to help them better document contracts and behaviours. - -Official Public API: - * :data:`.ArrayLike` - * :attr:`.ArrayType` - * :class:`.RoleLike` - * :class:`.Builder` - * :class:`.Descriptor` - * :class:`.HasHolders` - * :class:`.HasPlural` - * :class:`.HasVariables` - * :class:`.SupportsEncode` - * :class:`.SupportsFormula` - * :class:`.SupportsRole` - -Note: - How imports are being used today:: - - from openfisca_core.types import * # Bad - from openfisca_core.types.data_types.arrays import ArrayLike # Bad - - The previous examples provoke cyclic dependency problems, that prevents us - from modularizing the different components of the library, so as to make - them easier to test and to maintain. - - How could them be used after the next major release:: - - from openfisca_core.types import ArrayLike - - ArrayLike # Good: import types as publicly exposed - - .. seealso:: `PEP8#Imports`_ and `OpenFisca's Styleguide`_. - - .. _PEP8#Imports: - https://www.python.org/dev/peps/pep-0008/#imports - - .. _OpenFisca's Styleguide: - https://github.com/openfisca/openfisca-core/blob/master/STYLEGUIDE.md - -""" - -# Official Public API - -from .data_types import ( # noqa: F401 - ArrayLike, - ArrayType, - RoleLike, - ) - -__all__ = ["ArrayLike", "ArrayType", "RoleLike"] - -from .protocols import ( # noqa: F401 - Builder, - Descriptor, - HasHolders, - HasPlural, - HasVariables, - SupportsEncode, - SupportsFormula, - SupportsRole, - ) - -__all__ = ["Builder", "Descriptor", "HasHolders", "HasPlural", *__all__] -__all__ = ["HasVariables", "SupportsEncode", "SupportsFormula", *__all__] -__all__ = ["SupportsRole", *__all__] diff --git a/openfisca_core/types/data_types/__init__.py b/openfisca_core/types/data_types/__init__.py deleted file mode 100644 index 503b0fc0d8..0000000000 --- a/openfisca_core/types/data_types/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .arrays import ArrayLike, ArrayType # noqa: F401 -from .roles import RoleLike # noqa: F401 diff --git a/openfisca_core/types/data_types/arrays.py b/openfisca_core/types/data_types/arrays.py deleted file mode 100644 index 3e941b4ab1..0000000000 --- a/openfisca_core/types/data_types/arrays.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Sequence, TypeVar, Union - -from nptyping import NDArray as ArrayType - -T = TypeVar("T", bool, bytes, float, int, object, str) - -A = Union[ - ArrayType[bool], - ArrayType[bytes], - ArrayType[float], - ArrayType[int], - ArrayType[object], - ArrayType[str], - ] - -ArrayLike = Union[A, Sequence[T]] -""":obj:`typing.Generic`: Type of any castable to :class:`numpy.ndarray`. - -These include any :obj:`numpy.ndarray` and sequences (like -:obj:`list`, :obj:`tuple`, and so on). - -Examples: - >>> ArrayLike[float] - typing.Union[numpy.ndarray, typing.Sequence[float]] - - >>> ArrayLike[str] - typing.Union[numpy.ndarray, typing.Sequence[str]] - -Note: - It is possible since numpy version 1.21 to specify the type of an - array, thanks to `numpy.typing.NDArray`_:: - - from numpy.typing import NDArray - NDArray[numpy.float64] - - `mypy`_ provides `duck type compatibility`_, so an :obj:`int` is - considered to be valid whenever a :obj:`float` is expected. - -Todo: - * Refactor once numpy version >= 1.21 is used. - -.. versionadded:: 35.5.0 - -.. versionchanged:: 35.6.0 - Moved to :mod:`.types` - -.. _mypy: - https://mypy.readthedocs.io/en/stable/ - -.. _duck type compatibility: - https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html - -.. _numpy.typing.NDArray: - https://numpy.org/doc/stable/reference/typing.html#numpy.typing.NDArray - -""" diff --git a/openfisca_core/types/data_types/roles.py b/openfisca_core/types/data_types/roles.py deleted file mode 100644 index d6de0d3976..0000000000 --- a/openfisca_core/types/data_types/roles.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Iterable, Optional -from typing_extensions import TypedDict - - -class RoleLike(TypedDict, total = False): - """Base type for any data castable to a role-like model. - - Type-checking against abstractions rather than implementations helps in - (a) decoupling the codebse, thanks to structural subtyping, and - (b) documenting/enforcing the blueprints of the different OpenFisca models. - - .. versionadded:: 35.7.0 - - """ - - key: str - plural: Optional[str] - label: Optional[str] - doc: Optional[str] - max: Optional[int] - subroles: Optional[Iterable[str]] diff --git a/openfisca_core/types/protocols/__init__.py b/openfisca_core/types/protocols/__init__.py deleted file mode 100644 index d44e025734..0000000000 --- a/openfisca_core/types/protocols/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .builder import Builder # noqa: F401 -from .descriptor import Descriptor # noqa: F401 -from .has_holders import HasHolders # noqa: F401 -from .has_plural import HasPlural # noqa: F401 -from .has_variables import HasVariables # noqa: F401 -from .supports_encode import SupportsEncode # noqa: F401 -from .supports_formula import SupportsFormula # noqa: F401 -from .supports_role import SupportsRole # noqa: F401 diff --git a/openfisca_core/types/protocols/_documentable.py b/openfisca_core/types/protocols/_documentable.py deleted file mode 100644 index 8708fd9391..0000000000 --- a/openfisca_core/types/protocols/_documentable.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing_extensions import Protocol - - -class Documentable(Protocol): - """Base type for any model that can be documented. - - Type-checking against abstractions rather than implementations helps in - (a) decoupling the codebse, thanks to structural subtyping, and - (b) documenting/enforcing the blueprints of the different OpenFisca models. - - .. versionadded:: 35.7.0 - - """ - - key: str diff --git a/openfisca_core/types/protocols/builder.py b/openfisca_core/types/protocols/builder.py index cd57484276..7120daa53e 100644 --- a/openfisca_core/types/protocols/builder.py +++ b/openfisca_core/types/protocols/builder.py @@ -1,13 +1,13 @@ from __future__ import annotations -import abc from typing import Iterable, Sequence, Type, TypeVar - from typing_extensions import Protocol -RT = TypeVar("RT", covariant = True) -ET = TypeVar("ET", covariant = True) -EL = TypeVar("EL", contravariant = True) +import abc + +RT = TypeVar("RT", covariant=True) +ET = TypeVar("ET", covariant=True) +EL = TypeVar("EL", contravariant=True) class Builder(Protocol[RT, ET, EL]): @@ -22,8 +22,7 @@ class Builder(Protocol[RT, ET, EL]): """ @abc.abstractmethod - def __init__(self, __builder: RT, __buildee: Type[ET]) -> None: - ... + def __init__(self, __builder: RT, __buildee: Type[ET]) -> None: ... @abc.abstractmethod def __call__(self, __items: Iterable[EL]) -> Sequence[ET]: diff --git a/openfisca_core/types/protocols/descriptor.py b/openfisca_core/types/protocols/descriptor.py index 296f61d2d9..aa5406960f 100644 --- a/openfisca_core/types/protocols/descriptor.py +++ b/openfisca_core/types/protocols/descriptor.py @@ -1,8 +1,7 @@ from typing import Any, Optional, Type, TypeVar - from typing_extensions import Protocol -T = TypeVar("T", covariant = True) +T = TypeVar("T", covariant=True) class Descriptor(Protocol[T]): @@ -16,8 +15,6 @@ class Descriptor(Protocol[T]): """ - def __get__(self, __instance: Any, __owner: Type[Any]) -> Optional[T]: - ... + def __get__(self, __instance: Any, __owner: Type[Any]) -> Optional[T]: ... - def __set__(self, __instance: Any, __value: Any) -> None: - ... + def __set__(self, __instance: Any, __value: Any) -> None: ... diff --git a/openfisca_core/types/protocols/has_holders.py b/openfisca_core/types/protocols/has_holders.py deleted file mode 100644 index 95a1a53b2a..0000000000 --- a/openfisca_core/types/protocols/has_holders.py +++ /dev/null @@ -1,22 +0,0 @@ -import abc -from typing import Any, Optional - -from typing_extensions import Protocol - - -class HasHolders(Protocol): - """Base type for any population-like model. - - Type-checking against abstractions rather than implementations helps in - (a) decoupling the codebse, thanks to structural subtyping, and - (b) documenting/enforcing the blueprints of the different OpenFisca models. - - .. versionadded:: 35.7.0 - - """ - - @abc.abstractmethod - def get_holder(self, __arg1: str) -> Optional[Any]: - """A concrete representable model implements :meth:`.get_holder`.""" - - ... diff --git a/openfisca_core/types/protocols/has_plural.py b/openfisca_core/types/protocols/has_plural.py deleted file mode 100644 index 01ef48edb8..0000000000 --- a/openfisca_core/types/protocols/has_plural.py +++ /dev/null @@ -1,39 +0,0 @@ -import abc -from typing import Any, Iterator, Tuple - -import typing_extensions -from typing_extensions import Protocol - -from ._documentable import Documentable -from .has_variables import HasVariables -from .supports_formula import SupportsFormula - -T = HasVariables -V = SupportsFormula - - -@typing_extensions.runtime_checkable -class HasPlural(Documentable, Protocol): - """Base type for any entity-like model. - - Type-checking against abstractions rather than implementations helps in - (a) decoupling the codebse, thanks to structural subtyping, and - (b) documenting/enforcing the blueprints of the different OpenFisca models. - - .. versionadded:: 35.7.0 - - """ - - plural: str - - @abc.abstractmethod - def __repr__(self) -> str: - ... - - @abc.abstractmethod - def __str__(self) -> str: - ... - - @abc.abstractmethod - def __iter__(self) -> Iterator[Tuple[str, Any]]: - ... diff --git a/openfisca_core/types/protocols/has_variables.py b/openfisca_core/types/protocols/has_variables.py deleted file mode 100644 index 3a3edea6eb..0000000000 --- a/openfisca_core/types/protocols/has_variables.py +++ /dev/null @@ -1,22 +0,0 @@ -import abc -from typing import Any, Optional - -from typing_extensions import Protocol - - -class HasVariables(Protocol): - """Base type for any ruleset-like model. - - Type-checking against abstractions rather than implementations helps in - (a) decoupling the codebse, thanks to structural subtyping, and - (b) documenting/enforcing the blueprints of the different OpenFisca models. - - .. versionadded:: 35.7.0 - - """ - - @abc.abstractmethod - def get_variable(self, __arg1: str, __arg2: bool = False) -> Optional[Any]: - """A concrete representable model implements :meth:`.get_variable`.""" - - ... diff --git a/openfisca_core/types/protocols/supports_encode.py b/openfisca_core/types/protocols/supports_encode.py deleted file mode 100644 index cae51b1d33..0000000000 --- a/openfisca_core/types/protocols/supports_encode.py +++ /dev/null @@ -1,23 +0,0 @@ -import abc -from typing import Any - -from typing_extensions import Protocol - - -class SupportsEncode(Protocol): - """Base type for any model implementing a literal list of choices. - - Type-checking against abstractions rather than implementations helps in - (a) decoupling the codebse, thanks to structural subtyping, and - (b) documenting/enforcing the blueprints of the different OpenFisca models. - - .. versionadded:: 35.8.0 - - """ - - @classmethod - @abc.abstractmethod - def encode(cls, array: Any) -> Any: - """A concrete encodable model implements :meth:`.encode`.""" - - ... diff --git a/openfisca_core/types/protocols/supports_formula.py b/openfisca_core/types/protocols/supports_formula.py deleted file mode 100644 index 307190f324..0000000000 --- a/openfisca_core/types/protocols/supports_formula.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing_extensions import Protocol - - -class SupportsFormula(Protocol): - """Base type for any modelable element of the legislation. - - Type-checking against abstractions rather than implementations helps in - (a) decoupling the codebse, thanks to structural subtyping, and - (b) documenting/enforcing the blueprints of the different OpenFisca models. - - .. versionadded:: 35.7.0 - - """ - - ... diff --git a/openfisca_core/types/protocols/supports_role.py b/openfisca_core/types/protocols/supports_role.py deleted file mode 100644 index 6e8a14db55..0000000000 --- a/openfisca_core/types/protocols/supports_role.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -import abc -from typing import Any, Optional, Iterator, Sequence, Tuple - -import typing_extensions -from typing_extensions import Protocol - -from ..data_types import RoleLike -from ._documentable import Documentable -from .has_plural import HasPlural - -R = RoleLike -G = HasPlural - - -@typing_extensions.runtime_checkable -class SupportsRole(Documentable, Protocol): - """Base type for any role-like model. - - Type-checking against abstractions rather than implementations helps in - (a) decoupling the codebse, thanks to structural subtyping, and - (b) documenting/enforcing the blueprints of the different OpenFisca models. - - .. versionadded:: 35.7.0 - - """ - - max: Optional[int] - subroles: Optional[Sequence[SupportsRole]] - - @abc.abstractmethod - def __repr__(self) -> str: - ... - - @abc.abstractmethod - def __str__(self) -> str: - ... - - @abc.abstractmethod - def __iter__(self) -> Iterator[Tuple[str, Any]]: - ... - - @abc.abstractmethod - def __init__(self, __arg1: R, __arg2: G) -> None: - ... diff --git a/openfisca_core/variables/__init__.py b/openfisca_core/variables/__init__.py index 3decaf8f42..1ab191c5ce 100644 --- a/openfisca_core/variables/__init__.py +++ b/openfisca_core/variables/__init__.py @@ -21,6 +21,6 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .config import VALUE_TYPES, FORMULA_NAME_PREFIX # noqa: F401 +from .config import FORMULA_NAME_PREFIX, VALUE_TYPES # noqa: F401 from .helpers import get_annualized_variable, get_neutralized_variable # noqa: F401 from .variable import Variable # noqa: F401 diff --git a/openfisca_core/variables/config.py b/openfisca_core/variables/config.py index b260bb3dd9..54270145bf 100644 --- a/openfisca_core/variables/config.py +++ b/openfisca_core/variables/config.py @@ -5,50 +5,49 @@ from openfisca_core import indexed_enums from openfisca_core.indexed_enums import Enum - VALUE_TYPES = { bool: { - 'dtype': numpy.bool_, - 'default': False, - 'json_type': 'boolean', - 'formatted_value_type': 'Boolean', - 'is_period_size_independent': True - }, + "dtype": numpy.bool_, + "default": False, + "json_type": "boolean", + "formatted_value_type": "Boolean", + "is_period_size_independent": True, + }, int: { - 'dtype': numpy.int32, - 'default': 0, - 'json_type': 'integer', - 'formatted_value_type': 'Int', - 'is_period_size_independent': False - }, + "dtype": numpy.int32, + "default": 0, + "json_type": "integer", + "formatted_value_type": "Int", + "is_period_size_independent": False, + }, float: { - 'dtype': numpy.float32, - 'default': 0, - 'json_type': 'number', - 'formatted_value_type': 'Float', - 'is_period_size_independent': False, - }, + "dtype": numpy.float32, + "default": 0, + "json_type": "number", + "formatted_value_type": "Float", + "is_period_size_independent": False, + }, str: { - 'dtype': object, - 'default': '', - 'json_type': 'string', - 'formatted_value_type': 'String', - 'is_period_size_independent': True - }, + "dtype": object, + "default": "", + "json_type": "string", + "formatted_value_type": "String", + "is_period_size_independent": True, + }, Enum: { - 'dtype': indexed_enums.ENUM_ARRAY_DTYPE, - 'json_type': 'string', - 'formatted_value_type': 'String', - 'is_period_size_independent': True, - }, + "dtype": indexed_enums.ENUM_ARRAY_DTYPE, + "json_type": "string", + "formatted_value_type": "String", + "is_period_size_independent": True, + }, datetime.date: { - 'dtype': 'datetime64[D]', - 'default': datetime.date.fromtimestamp(0), # 0 == 1970-01-01 - 'json_type': 'string', - 'formatted_value_type': 'Date', - 'is_period_size_independent': True, - }, - } + "dtype": "datetime64[D]", + "default": datetime.date.fromtimestamp(0), # 0 == 1970-01-01 + "json_type": "string", + "formatted_value_type": "Date", + "is_period_size_independent": True, + }, +} -FORMULA_NAME_PREFIX = 'formula' +FORMULA_NAME_PREFIX = "formula" diff --git a/openfisca_core/variables/helpers.py b/openfisca_core/variables/helpers.py index 335a585498..5038a78240 100644 --- a/openfisca_core/variables/helpers.py +++ b/openfisca_core/variables/helpers.py @@ -1,23 +1,24 @@ from __future__ import annotations import sortedcontainers -from typing import Optional +from openfisca_core import variables from openfisca_core.periods import Period -from .. import variables - -def get_annualized_variable(variable: variables.Variable, annualization_period: Optional[Period] = None) -> variables.Variable: - """ - Returns a clone of ``variable`` that is annualized for the period ``annualization_period``. +def get_annualized_variable( + variable: variables.Variable, + annualization_period: Period | None = None, +) -> variables.Variable: + """Returns a clone of ``variable`` that is annualized for the period ``annualization_period``. When annualized, a variable's formula is only called for a January calculation, and the results for other months are assumed to be identical. """ - def make_annual_formula(original_formula, annualization_period = None): - + def make_annual_formula(original_formula, annualization_period=None): def annual_formula(population, period, parameters): - if period.start.month != 1 and (annualization_period is None or annualization_period.contains(period)): + if period.start.month != 1 and ( + annualization_period is None or annualization_period.contains(period) + ): return population(variable.name, period.this_year.first_month) if original_formula.__code__.co_argcount == 2: return original_formula(population, period) @@ -26,22 +27,29 @@ def annual_formula(population, period, parameters): return annual_formula new_variable = variable.clone() - new_variable.formulas = sortedcontainers.sorteddict.SortedDict({ - key: make_annual_formula(formula, annualization_period) - for key, formula in variable.formulas.items() - }) + new_variable.formulas = sortedcontainers.sorteddict.SortedDict( + { + key: make_annual_formula(formula, annualization_period) + for key, formula in variable.formulas.items() + }, + ) return new_variable def get_neutralized_variable(variable): - """ - Return a new neutralized variable (to be used by reforms). + """Return a new neutralized variable (to be used by reforms). A neutralized variable always returns its default value, and does not cache anything. """ result = variable.clone() result.is_neutralized = True - result.label = '[Neutralized]' if variable.label is None else '[Neutralized] {}'.format(variable.label), + result.label = ( + ( + "[Neutralized]" + if variable.label is None + else f"[Neutralized] {variable.label}" + ), + ) return result diff --git a/openfisca_core/variables/tests/__init__.py b/openfisca_core/variables/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/openfisca_core/variables/tests/test_definition_period.py b/openfisca_core/variables/tests/test_definition_period.py new file mode 100644 index 0000000000..8ef9bfaa87 --- /dev/null +++ b/openfisca_core/variables/tests/test_definition_period.py @@ -0,0 +1,43 @@ +import pytest + +from openfisca_core import periods +from openfisca_core.variables import Variable + + +@pytest.fixture +def variable(persons): + class TestVariable(Variable): + value_type = float + entity = persons + + return TestVariable + + +def test_weekday_variable(variable) -> None: + variable.definition_period = periods.WEEKDAY + assert variable() + + +def test_week_variable(variable) -> None: + variable.definition_period = periods.WEEK + assert variable() + + +def test_day_variable(variable) -> None: + variable.definition_period = periods.DAY + assert variable() + + +def test_month_variable(variable) -> None: + variable.definition_period = periods.MONTH + assert variable() + + +def test_year_variable(variable) -> None: + variable.definition_period = periods.YEAR + assert variable() + + +def test_eternity_variable(variable) -> None: + variable.definition_period = periods.ETERNITY + assert variable() diff --git a/openfisca_core/variables/variable.py b/openfisca_core/variables/variable.py index acfeb9fe70..926e4c59c1 100644 --- a/openfisca_core/variables/variable.py +++ b/openfisca_core/variables/variable.py @@ -1,22 +1,24 @@ +from __future__ import annotations + +from typing import NoReturn + import datetime -import inspect import re import textwrap -import sortedcontainers import numpy +import sortedcontainers -from openfisca_core import periods, tools -from openfisca_core.entities import Entity +from openfisca_core import commons, periods, types as t +from openfisca_core.entities import Entity, GroupEntity from openfisca_core.indexed_enums import Enum, EnumArray -from openfisca_core.periods import Period +from openfisca_core.periods import DateUnit, Period from . import config, helpers class Variable: - """ - A `variable `_ of the legislation. + """A `variable `_ of the legislation. Main attributes: @@ -34,7 +36,7 @@ class Variable: .. attribute:: definition_period - `Period `_ the variable is defined for. Possible value: ``MONTH``, ``YEAR``, ``ETERNITY``. + `Period `_ the variable is defined for. Possible value: ``DateUnit.DAY``, ``DateUnit.MONTH``, ``DateUnit.YEAR``, ``DateUnit.ETERNITY``. .. attribute:: formulas @@ -95,64 +97,137 @@ class Variable: Free multilines text field describing the variable context and usage. """ - def __init__(self, baseline_variable = None): + __name__: str + + def __init__(self, baseline_variable=None) -> None: self.name = self.__class__.__name__ attr = { - name: value for name, value in self.__class__.__dict__.items() - if not name.startswith('__')} + name: value + for name, value in self.__class__.__dict__.items() + if not name.startswith("__") + } self.baseline_variable = baseline_variable - self.value_type = self.set(attr, 'value_type', required = True, allowed_values = config.VALUE_TYPES.keys()) - self.dtype = config.VALUE_TYPES[self.value_type]['dtype'] - self.json_type = config.VALUE_TYPES[self.value_type]['json_type'] + self.value_type = self.set( + attr, + "value_type", + required=True, + allowed_values=config.VALUE_TYPES.keys(), + ) + self.dtype = config.VALUE_TYPES[self.value_type]["dtype"] + self.json_type = config.VALUE_TYPES[self.value_type]["json_type"] if self.value_type == Enum: - self.possible_values = self.set(attr, 'possible_values', required = True, setter = self.set_possible_values) + self.possible_values = self.set( + attr, + "possible_values", + required=True, + setter=self.set_possible_values, + ) if self.value_type == str: - self.max_length = self.set(attr, 'max_length', allowed_type = int) + self.max_length = self.set(attr, "max_length", allowed_type=int) if self.max_length: - self.dtype = '|S{}'.format(self.max_length) + self.dtype = f"|S{self.max_length}" if self.value_type == Enum: - self.default_value = self.set(attr, 'default_value', allowed_type = self.possible_values, required = True) + self.default_value = self.set( + attr, + "default_value", + allowed_type=self.possible_values, + required=True, + ) else: - self.default_value = self.set(attr, 'default_value', allowed_type = self.value_type, default = config.VALUE_TYPES[self.value_type].get('default')) - self.entity = self.set(attr, 'entity', required = True, setter = self.set_entity) - self.definition_period = self.set(attr, 'definition_period', required = True, allowed_values = (periods.DAY, periods.MONTH, periods.YEAR, periods.ETERNITY)) - self.label = self.set(attr, 'label', allowed_type = str, setter = self.set_label) - self.end = self.set(attr, 'end', allowed_type = str, setter = self.set_end) - self.reference = self.set(attr, 'reference', setter = self.set_reference) - self.cerfa_field = self.set(attr, 'cerfa_field', allowed_type = (str, dict)) - self.unit = self.set(attr, 'unit', allowed_type = str) - self.documentation = self.set(attr, 'documentation', allowed_type = str, setter = self.set_documentation) - self.set_input = self.set_set_input(attr.pop('set_input', None)) - self.calculate_output = self.set_calculate_output(attr.pop('calculate_output', None)) - self.is_period_size_independent = self.set(attr, 'is_period_size_independent', allowed_type = bool, default = config.VALUE_TYPES[self.value_type]['is_period_size_independent']) - - formulas_attr, unexpected_attrs = helpers._partition(attr, lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX)) + self.default_value = self.set( + attr, + "default_value", + allowed_type=self.value_type, + default=config.VALUE_TYPES[self.value_type].get("default"), + ) + self.entity = self.set(attr, "entity", required=True, setter=self.set_entity) + self.definition_period = self.set( + attr, + "definition_period", + required=True, + allowed_values=DateUnit, + ) + self.label = self.set(attr, "label", allowed_type=str, setter=self.set_label) + self.end = self.set(attr, "end", allowed_type=str, setter=self.set_end) + self.reference = self.set(attr, "reference", setter=self.set_reference) + self.cerfa_field = self.set(attr, "cerfa_field", allowed_type=(str, dict)) + self.unit = self.set(attr, "unit", allowed_type=str) + self.documentation = self.set( + attr, + "documentation", + allowed_type=str, + setter=self.set_documentation, + ) + self.set_input = self.set_set_input(attr.pop("set_input", None)) + self.calculate_output = self.set_calculate_output( + attr.pop("calculate_output", None), + ) + self.is_period_size_independent = self.set( + attr, + "is_period_size_independent", + allowed_type=bool, + default=config.VALUE_TYPES[self.value_type]["is_period_size_independent"], + ) + + self.introspection_data = self.set( + attr, + "introspection_data", + ) + + formulas_attr, unexpected_attrs = helpers._partition( + attr, + lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX), + ) self.formulas = self.set_formulas(formulas_attr) if unexpected_attrs: + msg = 'Unexpected attributes in definition of variable "{}": {!r}'.format( + self.name, + ", ".join(sorted(unexpected_attrs.keys())), + ) raise ValueError( - 'Unexpected attributes in definition of variable "{}": {!r}' - .format(self.name, ', '.join(sorted(unexpected_attrs.keys())))) + msg, + ) self.is_neutralized = False # ----- Setters used to build the variable ----- # - def set(self, attributes, attribute_name, required = False, allowed_values = None, allowed_type = None, setter = None, default = None): + def set( + self, + attributes, + attribute_name, + required=False, + allowed_values=None, + allowed_type=None, + setter=None, + default=None, + ): value = attributes.pop(attribute_name, None) if value is None and self.baseline_variable: return getattr(self.baseline_variable, attribute_name) if required and value is None: - raise ValueError("Missing attribute '{}' in definition of variable '{}'.".format(attribute_name, self.name)) + msg = f"Missing attribute '{attribute_name}' in definition of variable '{self.name}'." + raise ValueError( + msg, + ) if allowed_values is not None and value not in allowed_values: - raise ValueError("Invalid value '{}' for attribute '{}' in variable '{}'. Allowed values are '{}'." - .format(value, attribute_name, self.name, allowed_values)) - if allowed_type is not None and value is not None and not isinstance(value, allowed_type): + msg = f"Invalid value '{value}' for attribute '{attribute_name}' in variable '{self.name}'. Allowed values are '{allowed_values}'." + raise ValueError( + msg, + ) + if ( + allowed_type is not None + and value is not None + and not isinstance(value, allowed_type) + ): if allowed_type == float and isinstance(value, int): value = float(value) else: - raise ValueError("Invalid value '{}' for attribute '{}' in variable '{}'. Must be of type '{}'." - .format(value, attribute_name, self.name, allowed_type)) + msg = f"Invalid value '{value}' for attribute '{attribute_name}' in variable '{self.name}'. Must be of type '{allowed_type}'." + raise ValueError( + msg, + ) if setter is not None: value = setter(value) if value is None and default is not None: @@ -160,26 +235,39 @@ def set(self, attributes, attribute_name, required = False, allowed_values = Non return value def set_entity(self, entity): - if not isinstance(entity, Entity): - raise ValueError(f"Invalid value '{entity}' for attribute 'entity' in variable '{self.name}'. Must be an instance of Entity.") + if not isinstance(entity, (Entity, GroupEntity)): + msg = ( + f"Invalid value '{entity}' for attribute 'entity' in variable " + f"'{self.name}'. Must be an instance of Entity or GroupEntity." + ) + raise ValueError( + msg, + ) return entity def set_possible_values(self, possible_values): if not issubclass(possible_values, Enum): - raise ValueError("Invalid value '{}' for attribute 'possible_values' in variable '{}'. Must be a subclass of {}." - .format(possible_values, self.name, Enum)) + msg = f"Invalid value '{possible_values}' for attribute 'possible_values' in variable '{self.name}'. Must be a subclass of {Enum}." + raise ValueError( + msg, + ) return possible_values def set_label(self, label): if label: return label + return None def set_end(self, end): if end: try: - return datetime.datetime.strptime(end, '%Y-%m-%d').date() + return datetime.datetime.strptime(end, "%Y-%m-%d").date() except ValueError: - raise ValueError("Incorrect 'end' attribute format in '{}'. 'YYYY-MM-DD' expected where YYYY, MM and DD are year, month and day. Found: {}".format(self.name, end)) + msg = f"Incorrect 'end' attribute format in '{self.name}'. 'YYYY-MM-DD' expected where YYYY, MM and DD are year, month and day. Found: {end}" + raise ValueError( + msg, + ) + return None def set_reference(self, reference): if reference: @@ -190,19 +278,24 @@ def set_reference(self, reference): elif isinstance(reference, tuple): reference = list(reference) else: - raise TypeError('The reference of the variable {} is a {} instead of a String or a List of Strings.'.format(self.name, type(reference))) + msg = f"The reference of the variable {self.name} is a {type(reference)} instead of a String or a List of Strings." + raise TypeError( + msg, + ) for element in reference: if not isinstance(element, str): + msg = f"The reference of the variable {self.name} is a {type(reference)} instead of a String or a List of Strings." raise TypeError( - 'The reference of the variable {} is a {} instead of a String or a List of Strings.'.format( - self.name, type(reference))) + msg, + ) return reference def set_documentation(self, documentation): if documentation: return textwrap.dedent(documentation) + return None def set_set_input(self, set_input): if not set_input and self.baseline_variable: @@ -220,25 +313,29 @@ def set_formulas(self, formulas_attr): starting_date = self.parse_formula_name(formula_name) if self.end is not None and starting_date > self.end: - raise ValueError('You declared that "{}" ends on "{}", but you wrote a formula to calculate it from "{}" ({}). The "end" attribute of a variable must be posterior to the start dates of all its formulas.' - .format(self.name, self.end, starting_date, formula_name)) + msg = f'You declared that "{self.name}" ends on "{self.end}", but you wrote a formula to calculate it from "{starting_date}" ({formula_name}). The "end" attribute of a variable must be posterior to the start dates of all its formulas.' + raise ValueError( + msg, + ) formulas[str(starting_date)] = formula # If the variable is reforming a baseline variable, keep the formulas from the latter when they are not overridden by new formulas. if self.baseline_variable is not None: first_reform_formula_date = formulas.peekitem(0)[0] if formulas else None - formulas.update({ - baseline_start_date: baseline_formula - for baseline_start_date, baseline_formula in self.baseline_variable.formulas.items() - if first_reform_formula_date is None or baseline_start_date < first_reform_formula_date - }) + formulas.update( + { + baseline_start_date: baseline_formula + for baseline_start_date, baseline_formula in self.baseline_variable.formulas.items() + if first_reform_formula_date is None + or baseline_start_date < first_reform_formula_date + }, + ) return formulas def parse_formula_name(self, attribute_name): - """ - Returns the starting date of a formula based on its name. + """Returns the starting date of a formula based on its name. Valid dated name formats are : 'formula', 'formula_YYYY', 'formula_YYYY_MM' and 'formula_YYYY_MM_DD' where YYYY, MM and DD are a year, month and day. @@ -248,76 +345,69 @@ def parse_formula_name(self, attribute_name): - `formula_YYYY_MM` is `YYYY-MM-01` """ - def raise_error(): + def raise_error() -> NoReturn: + msg = f'Unrecognized formula name in variable "{self.name}". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: "{attribute_name}".' raise ValueError( - 'Unrecognized formula name in variable "{}". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: "{}".' - .format(self.name, attribute_name)) + msg, + ) if attribute_name == config.FORMULA_NAME_PREFIX: return datetime.date.min - FORMULA_REGEX = r'formula_(\d{4})(?:_(\d{2}))?(?:_(\d{2}))?$' # YYYY or YYYY_MM or YYYY_MM_DD + FORMULA_REGEX = r"formula_(\d{4})(?:_(\d{2}))?(?:_(\d{2}))?$" # YYYY or YYYY_MM or YYYY_MM_DD match = re.match(FORMULA_REGEX, attribute_name) if not match: raise_error() - date_str = '-'.join([match.group(1), match.group(2) or '01', match.group(3) or '01']) + date_str = "-".join( + [match.group(1), match.group(2) or "01", match.group(3) or "01"], + ) try: - return datetime.datetime.strptime(date_str, '%Y-%m-%d').date() + return datetime.datetime.strptime(date_str, "%Y-%m-%d").date() except ValueError: # formula_2005_99_99 for instance raise_error() # ----- Methods ----- # def is_input_variable(self): - """ - Returns True if the variable is an input variable. - """ + """Returns True if the variable is an input variable.""" return len(self.formulas) == 0 @classmethod - def get_introspection_data(cls, tax_benefit_system): - """ - Get instrospection data about the code of the variable. - - :returns: (comments, source file path, source code, start line number) - :rtype: tuple - - """ - comments = inspect.getcomments(cls) - - # Handle dynamically generated variable classes or Jupyter Notebooks, which have no source. + def get_introspection_data(cls): try: - absolute_file_path = inspect.getsourcefile(cls) - except TypeError: - source_file_path = None - else: - source_file_path = absolute_file_path.replace(tax_benefit_system.get_package_metadata()['location'], '') - try: - source_lines, start_line_number = inspect.getsourcelines(cls) - source_code = textwrap.dedent(''.join(source_lines)) - except (IOError, TypeError): - source_code, start_line_number = None, None + return cls.introspection_data + except AttributeError: + return "", None, 0 - return comments, source_file_path, source_code, start_line_number + def get_formula( + self, + period: None | t.Instant | t.Period | str | int = None, + ) -> None | t.Formula: + """Returns the formula to compute the variable at the given period. - def get_formula(self, period = None): - """ - Returns the formula used to compute the variable at the given period. + If no period is given and the variable has several formulas, the method + returns the oldest formula. - If no period is given and the variable has several formula, return the oldest formula. + Args: + period: The period to get the formula. - :returns: Formula used to compute the variable - :rtype: callable + Returns: + Formula used to compute the variable. """ + instant: None | t.Instant if not self.formulas: return None if period is None: - return self.formulas.peekitem(index = 0)[1] # peekitem gets the 1st key-value tuple (the oldest start_date and formula). Return the formula. + return self.formulas.peekitem( + index=0, + )[ + 1 + ] # peekitem gets the 1st key-value tuple (the oldest start_date and formula). Return the formula. if isinstance(period, Period): instant = period.start @@ -327,19 +417,22 @@ def get_formula(self, period = None): except ValueError: instant = periods.instant(period) + if instant is None: + return None + if self.end and instant.date > self.end: return None - instant = str(instant) + instant_str = str(instant) + for start_date in reversed(self.formulas): - if start_date <= instant: + if start_date <= instant_str: return self.formulas[start_date] return None def clone(self): - clone = self.__class__() - return clone + return self.__class__() def check_set_value(self, value): if self.value_type == Enum and isinstance(value, str): @@ -347,35 +440,39 @@ def check_set_value(self, value): value = self.possible_values[value].index except KeyError: possible_values = [item.name for item in self.possible_values] + msg = "'{}' is not a known value for '{}'. Possible values are ['{}'].".format( + value, + self.name, + "', '".join(possible_values), + ) raise ValueError( - "'{}' is not a known value for '{}'. Possible values are ['{}'].".format( - value, self.name, "', '".join(possible_values)) - ) + msg, + ) if self.value_type in (float, int) and isinstance(value, str): try: - value = tools.eval_expression(value) + value = commons.eval_expression(value) except SyntaxError: + msg = f"I couldn't understand '{value}' as a value for '{self.name}'" raise ValueError( - "I couldn't understand '{}' as a value for '{}'".format( - value, self.name) - ) + msg, + ) try: - value = numpy.array([value], dtype = self.dtype)[0] + value = numpy.array([value], dtype=self.dtype)[0] except (TypeError, ValueError): - if (self.value_type == datetime.date): - error_message = "Can't deal with date: '{}'.".format(value) + if self.value_type == datetime.date: + error_message = f"Can't deal with date: '{value}'." else: - error_message = "Can't deal with value: expected type {}, received '{}'.".format(self.json_type, value) + error_message = f"Can't deal with value: expected type {self.json_type}, received '{value}'." raise ValueError(error_message) - except (OverflowError): - error_message = "Can't deal with value: '{}', it's too large for type '{}'.".format(value, self.json_type) + except OverflowError: + error_message = f"Can't deal with value: '{value}', it's too large for type '{self.json_type}'." raise ValueError(error_message) return value def default_array(self, array_size): - array = numpy.empty(array_size, dtype = self.dtype) + array = numpy.empty(array_size, dtype=self.dtype) if self.value_type == Enum: array.fill(self.default_value.index) return EnumArray(array, self.possible_values) diff --git a/openfisca_core/warnings/libyaml_warning.py b/openfisca_core/warnings/libyaml_warning.py index 361a1688ad..7ea797b667 100644 --- a/openfisca_core/warnings/libyaml_warning.py +++ b/openfisca_core/warnings/libyaml_warning.py @@ -1,5 +1,2 @@ class LibYAMLWarning(UserWarning): - """ - Custom warning for LibYAML not installed. - """ - pass + """Custom warning for LibYAML not installed.""" diff --git a/openfisca_core/warnings/memory_warning.py b/openfisca_core/warnings/memory_warning.py index 8fcb1f46f4..23e82bf3e0 100644 --- a/openfisca_core/warnings/memory_warning.py +++ b/openfisca_core/warnings/memory_warning.py @@ -1,5 +1,2 @@ class MemoryConfigWarning(UserWarning): - """ - Custom warning for MemoryConfig. - """ - pass + """Custom warning for MemoryConfig.""" diff --git a/openfisca_core/warnings/tempfile_warning.py b/openfisca_core/warnings/tempfile_warning.py index cf2b9947ac..9f4aad3820 100644 --- a/openfisca_core/warnings/tempfile_warning.py +++ b/openfisca_core/warnings/tempfile_warning.py @@ -1,5 +1,2 @@ class TempfileWarning(UserWarning): - """ - Custom warning when using a tempfile on disk. - """ - pass + """Custom warning when using a tempfile on disk.""" diff --git a/openfisca_tasks/install.mk b/openfisca_tasks/install.mk new file mode 100644 index 0000000000..bb844b9d56 --- /dev/null +++ b/openfisca_tasks/install.mk @@ -0,0 +1,21 @@ +## Uninstall project's dependencies. +uninstall: + @$(call print_help,$@:) + @python -m pip freeze | grep -v "^-e" | sed "s/@.*//" | xargs pip uninstall -y + +## Install project's overall dependencies +install-deps: + @$(call print_help,$@:) + @python -m pip install --upgrade pip + +## Install project's development dependencies. +install-edit: + @$(call print_help,$@:) + @python -m pip install --upgrade --editable ".[dev]" + +## Delete builds and compiled python files. +clean: + @$(call print_help,$@:) + @ls -d * | grep "build\|dist" | xargs rm -rf + @find . -name "__pycache__" | xargs rm -rf + @find . -name "*.pyc" | xargs rm -rf diff --git a/openfisca_tasks/lint.mk b/openfisca_tasks/lint.mk new file mode 100644 index 0000000000..646cf76d70 --- /dev/null +++ b/openfisca_tasks/lint.mk @@ -0,0 +1,53 @@ +## Lint the codebase. +lint: check-syntax-errors check-style lint-doc + @$(call print_pass,$@:) + +## Compile python files to check for syntax errors. +check-syntax-errors: . + @$(call print_help,$@:) + @python -m compileall -q $? + @$(call print_pass,$@:) + +## Run linters to check for syntax and style errors. +check-style: $(shell git ls-files "*.py" "*.pyi") + @$(call print_help,$@:) + @python -m isort --check $? + @python -m black --check $? + @python -m flake8 $? + @$(call print_pass,$@:) + +## Run linters to check for syntax and style errors in the doc. +lint-doc: \ + lint-doc-commons \ + lint-doc-entities \ + ; + +## Run linters to check for syntax and style errors in the doc. +lint-doc-%: + @## These checks are exclusively related to doc/strings/test. + @## + @## They can be integrated into setup.cfg once all checks pass. + @## The reason they're here is because otherwise we wouldn't be + @## able to integrate documentation improvements progresively. + @## + @$(call print_help,$(subst $*,%,$@:)) + @python -m flake8 --select=D101,D102,D103,DAR openfisca_core/$* + @python -m pylint openfisca_core/$* + @$(call print_pass,$@:) + +## Run static type checkers for type errors. +check-types: + @$(call print_help,$@:) + @python -m mypy \ + openfisca_core/commons \ + openfisca_core/entities \ + openfisca_core/periods \ + openfisca_core/types.py + @$(call print_pass,$@:) + +## Run code formatters to correct style errors. +format-style: $(shell git ls-files "*.py" "*.pyi") + @$(call print_help,$@:) + @python -m isort $? + @python -m black $? + @$(call print_pass,$@:) diff --git a/openfisca_tasks/publish.mk b/openfisca_tasks/publish.mk new file mode 100644 index 0000000000..37e599b63f --- /dev/null +++ b/openfisca_tasks/publish.mk @@ -0,0 +1,25 @@ +.PHONY: build + +## Install project's build dependencies. +install-dist: + @$(call print_help,$@:) + @python -m pip install .[ci,dev] + @$(call print_pass,$@:) + +## Build & install openfisca-core for deployment and publishing. +build: + @## This allows us to be sure tests are run against the packaged version + @## of openfisca-core, the same we put in the hands of users and reusers. + @$(call print_help,$@:) + @python -m build + @python -m pip uninstall --yes openfisca-core + @find dist -name "*.whl" -exec python -m pip install --no-deps {} \; + @$(call print_pass,$@:) + +## Upload to PyPi. +publish: + @$(call print_help,$@:) + @python -m twine upload dist/* --username $PYPI_USERNAME --password $PYPI_TOKEN + @git tag `python setup.py --version` + @git push --tags # update the repository version + @$(call print_pass,$@:) diff --git a/openfisca_tasks/serve.mk b/openfisca_tasks/serve.mk new file mode 100644 index 0000000000..efad0be6cb --- /dev/null +++ b/openfisca_tasks/serve.mk @@ -0,0 +1,6 @@ +## Serve the openfisca Web API. +api: + @$(call print_help,$@:) + @openfisca serve \ + --country-package openfisca_country_template \ + --extensions openfisca_extension_template diff --git a/openfisca_tasks/test_code.mk b/openfisca_tasks/test_code.mk new file mode 100644 index 0000000000..8878fe9d33 --- /dev/null +++ b/openfisca_tasks/test_code.mk @@ -0,0 +1,75 @@ +## The openfisca command module. +openfisca = openfisca_core.scripts.openfisca_command + +## The path to the templates' tests. +ifeq ($(OS),Windows_NT) + tests = $(shell python -c "import os, $(1); print(repr(os.path.join($(1).__path__[0], 'tests')))") +else + tests = $(shell python -c "import $(1); print($(1).__path__[0])")/tests +endif + +## Run all tasks required for testing. +install: install-deps install-edit install-test + +## Enable regression testing with template repositories. +install-test: + @$(call print_help,$@:) + @python -m pip install --upgrade --no-deps openfisca-country-template + @python -m pip install --upgrade --no-deps openfisca-extension-template + +## Run openfisca-core & country/extension template tests. +test-code: test-core test-country test-extension + @## Usage: + @## + @## make test [pytest_args="--ARG"] [openfisca_args="--ARG"] + @## + @## Examples: + @## + @## make test + @## make test pytest_args="--exitfirst" + @## make test openfisca_args="--performance" + @## make test pytest_args="--exitfirst" openfisca_args="--performance" + @## + @$(call print_pass,$@:) + +## Run openfisca-core tests. +test-core: $(shell git ls-files "*test_*.py") + @$(call print_help,$@:) + @python -m pytest --capture=no --xdoctest --xdoctest-verbose=0 \ + openfisca_core/commons \ + openfisca_core/entities \ + openfisca_core/holders \ + openfisca_core/periods \ + openfisca_core/projectors + @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ + python -m coverage run -m ${openfisca} test \ + $? \ + ${openfisca_args} + @$(call print_pass,$@:) + +## Run country-template tests. +test-country: + @$(call print_help,$@:) + @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ + python -m ${openfisca} test \ + $(call tests,"openfisca_country_template") \ + --country-package openfisca_country_template \ + ${openfisca_args} + @$(call print_pass,$@:) + +## Run extension-template tests. +test-extension: + @$(call print_help,$@:) + @PYTEST_ADDOPTS="$${PYTEST_ADDOPTS} ${pytest_args}" \ + python -m ${openfisca} test \ + $(call tests,"openfisca_extension_template") \ + --country-package openfisca_country_template \ + --extensions openfisca_extension_template \ + ${openfisca_args} + @$(call print_pass,$@:) + +## Print the coverage report. +test-cov: + @$(call print_help,$@:) + @python -m coverage report + @$(call print_pass,$@:) diff --git a/openfisca_web_api/app.py b/openfisca_web_api/app.py index e2244e9ba2..a76f255a0c 100644 --- a/openfisca_web_api/app.py +++ b/openfisca_web_api/app.py @@ -1,50 +1,57 @@ -# -*- coding: utf-8 -*- - import logging import os import traceback -from openfisca_core.errors import SituationParsingError, PeriodMismatchError -from openfisca_web_api.loader import build_data -from openfisca_web_api.errors import handle_import_error +from openfisca_core.errors import PeriodMismatchError, SituationParsingError from openfisca_web_api import handlers +from openfisca_web_api.errors import handle_import_error +from openfisca_web_api.loader import build_data try: - from flask import Flask, jsonify, abort, request, make_response + import werkzeug.exceptions + from flask import Flask, abort, jsonify, make_response, redirect, request from flask_cors import CORS from werkzeug.middleware.proxy_fix import ProxyFix - import werkzeug.exceptions except ImportError as error: handle_import_error(error) -log = logging.getLogger('gunicorn.error') +log = logging.getLogger("gunicorn.error") def init_tracker(url, idsite, tracker_token): try: from openfisca_tracker.piwik import PiwikTracker + tracker = PiwikTracker(url, idsite, tracker_token) - info = os.linesep.join(['You chose to activate the `tracker` module. ', - 'Tracking data will be sent to: ' + url, - 'For more information, see .']) + info = os.linesep.join( + [ + "You chose to activate the `tracker` module. ", + "Tracking data will be sent to: " + url, + "For more information, see .", + ], + ) log.info(info) return tracker except ImportError: - message = os.linesep.join([traceback.format_exc(), - 'You chose to activate the `tracker` module, but it is not installed.', - 'For more information, see .']) - log.warn(message) - - -def create_app(tax_benefit_system, - tracker_url = None, - tracker_idsite = None, - tracker_token = None, - welcome_message = None, - ): - + message = os.linesep.join( + [ + traceback.format_exc(), + "You chose to activate the `tracker` module, but it is not installed.", + "For more information, see .", + ], + ) + log.warning(message) + + +def create_app( + tax_benefit_system, + tracker_url=None, + tracker_idsite=None, + tracker_token=None, + welcome_message=None, +): if not tracker_url or not tracker_idsite: tracker = None else: @@ -52,88 +59,108 @@ def create_app(tax_benefit_system, app = Flask(__name__) # Fix request.remote_addr to get the real client IP address - app.wsgi_app = ProxyFix(app.wsgi_app, x_for = 1, x_host = 1) - CORS(app, origins = '*') + app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_host=1) + CORS(app, origins="*") - app.config['JSON_AS_ASCII'] = False # When False, lets jsonify encode to utf-8 + app.config["JSON_AS_ASCII"] = False # When False, lets jsonify encode to utf-8 app.url_map.strict_slashes = False # Accept url like /parameters/ app.url_map.merge_slashes = False # Do not eliminate // in paths - app.config['JSON_SORT_KEYS'] = False # Don't sort JSON keys in the Web API + app.config["JSON_SORT_KEYS"] = False # Don't sort JSON keys in the Web API data = build_data(tax_benefit_system) DEFAULT_WELCOME_MESSAGE = "This is the root of an OpenFisca Web API. To learn how to use it, check the general documentation (https://openfisca.org/doc/) and the OpenAPI specification of this instance ({}spec)." - @app.route('/') - def get_root(): - return jsonify({ - 'welcome': welcome_message or DEFAULT_WELCOME_MESSAGE.format(request.host_url) - }), 300 + @app.before_request + def before_request(): + if request.path != "/" and request.path.endswith("/"): + return redirect(request.path[:-1]) + return None - @app.route('/parameters') + @app.route("/") + def get_root(): + return ( + jsonify( + { + "welcome": welcome_message + or DEFAULT_WELCOME_MESSAGE.format(request.host_url), + }, + ), + 300, + ) + + @app.route("/parameters") def get_parameters(): parameters = { - parameter['id']: { - 'description': parameter['description'], - 'href': '{}parameter/{}'.format(request.host_url, name) - } - for name, parameter in data['parameters'].items() - if parameter.get('subparams') is None # For now and for backward compat, don't show nodes in overview + parameter["id"]: { + "description": parameter["description"], + "href": f"{request.host_url}parameter/{name}", } + for name, parameter in data["parameters"].items() + if parameter.get("subparams") + is None # For now and for backward compat, don't show nodes in overview + } return jsonify(parameters) - @app.route('/parameter/') + @app.route("/parameter/") def get_parameter(parameter_id): - parameter = data['parameters'].get(parameter_id) + parameter = data["parameters"].get(parameter_id) if parameter is None: # Try legacy route - parameter_new_id = parameter_id.replace('.', '/') - parameter = data['parameters'].get(parameter_new_id) + parameter_new_id = parameter_id.replace(".", "/") + parameter = data["parameters"].get(parameter_new_id) if parameter is None: raise abort(404) return jsonify(parameter) - @app.route('/variables') + @app.route("/variables") def get_variables(): variables = { name: { - 'description': variable['description'], - 'href': '{}variable/{}'.format(request.host_url, name) - } - for name, variable in data['variables'].items() + "description": variable["description"], + "href": f"{request.host_url}variable/{name}", } + for name, variable in data["variables"].items() + } return jsonify(variables) - @app.route('/variable/') + @app.route("/variable/") def get_variable(id): - variable = data['variables'].get(id) + variable = data["variables"].get(id) if variable is None: raise abort(404) return jsonify(variable) - @app.route('/entities') + @app.route("/entities") def get_entities(): - return jsonify(data['entities']) + return jsonify(data["entities"]) - @app.route('/spec') + @app.route("/spec") def get_spec(): - return jsonify({ - **data['openAPI_spec'], - **{'host': request.host}, - **{'schemes': [request.environ['wsgi.url_scheme']]} - }) - - def handle_invalid_json(error): - json_response = jsonify({ - 'error': 'Invalid JSON: {}'.format(error.args[0]), - }) + scheme = request.environ["wsgi.url_scheme"] + host = request.host + url = f"{scheme}://{host}" + + return jsonify( + { + **data["openAPI_spec"], + "servers": [{"url": url}], + }, + ) + + def handle_invalid_json(error) -> None: + json_response = jsonify( + { + "error": f"Invalid JSON: {error.args[0]}", + }, + ) abort(make_response(json_response, 400)) - @app.route('/calculate', methods=['POST']) + @app.route("/calculate", methods=["POST"]) def calculate(): - tax_benefit_system = data['tax_benefit_system'] + tax_benefit_system = data["tax_benefit_system"] request.on_json_loading_failed = handle_invalid_json input_data = request.get_json() try: @@ -141,12 +168,17 @@ def calculate(): except (SituationParsingError, PeriodMismatchError) as e: abort(make_response(jsonify(e.error), e.code or 400)) except (UnicodeEncodeError, UnicodeDecodeError) as e: - abort(make_response(jsonify({"error": "'" + e[1] + "' is not a valid ASCII value."}), 400)) + abort( + make_response( + jsonify({"error": "'" + e[1] + "' is not a valid ASCII value."}), + 400, + ), + ) return jsonify(result) - @app.route('/trace', methods=['POST']) + @app.route("/trace", methods=["POST"]) def trace(): - tax_benefit_system = data['tax_benefit_system'] + tax_benefit_system = data["tax_benefit_system"] request.on_json_loading_failed = handle_invalid_json input_data = request.get_json() try: @@ -157,25 +189,28 @@ def trace(): @app.after_request def apply_headers(response): - response.headers.extend({ - 'Country-Package': data['country_package_metadata']['name'], - 'Country-Package-Version': data['country_package_metadata']['version'] - }) + response.headers.extend( + { + "Country-Package": data["country_package_metadata"]["name"], + "Country-Package-Version": data["country_package_metadata"]["version"], + }, + ) return response @app.after_request def track_requests(response): - if tracker: - if request.headers.get('dnt'): + if request.headers.get("dnt"): source_ip = "" - elif request.headers.get('X-Forwarded-For'): - source_ip = request.headers['X-Forwarded-For'].split(', ')[0] + elif request.headers.get("X-Forwarded-For"): + source_ip = request.headers["X-Forwarded-For"].split(", ")[0] else: source_ip = request.remote_addr - api_version = "{}-{}".format(data['country_package_metadata']['name'], - data['country_package_metadata']['version']) + api_version = "{}-{}".format( + data["country_package_metadata"]["name"], + data["country_package_metadata"]["version"], + ) tracker.track(request.url, source_ip, api_version, request.path) return response diff --git a/openfisca_web_api/errors.py b/openfisca_web_api/errors.py index 96c95a6874..ac93ebd833 100644 --- a/openfisca_web_api/errors.py +++ b/openfisca_web_api/errors.py @@ -1,9 +1,12 @@ -# -*- coding: utf-8 -*- +from typing import NoReturn import logging -log = logging.getLogger('gunicorn.error') +log = logging.getLogger("gunicorn.error") -def handle_import_error(error): - raise ImportError("OpenFisca is missing some dependencies to run the Web API: '{}'. To install them, run `pip install openfisca_core[web-api]`.".format(error)) +def handle_import_error(error) -> NoReturn: + msg = f"OpenFisca is missing some dependencies to run the Web API: '{error}'. To install them, run `pip install openfisca_core[web-api]`." + raise ImportError( + msg, + ) diff --git a/openfisca_web_api/handlers.py b/openfisca_web_api/handlers.py index 9c8826772c..2f6fc4403a 100644 --- a/openfisca_web_api/handlers.py +++ b/openfisca_web_api/handlers.py @@ -1,20 +1,24 @@ -# -*- coding: utf-8 -*- +import dpath.util -import dpath - -from openfisca_core.simulation_builder import SimulationBuilder from openfisca_core.indexed_enums import Enum +from openfisca_core.simulation_builder import SimulationBuilder -def calculate(tax_benefit_system, input_data): +def calculate(tax_benefit_system, input_data: dict) -> dict: + """Returns the input_data where the None values are replaced by the calculated values.""" simulation = SimulationBuilder().build_from_entities(tax_benefit_system, input_data) - - requested_computations = dpath.util.search(input_data, '*/*/*/*', afilter = lambda t: t is None, yielded = True) - computation_results = {} - + requested_computations = dpath.util.search( + input_data, + "*/*/*/*", + afilter=lambda t: t is None, + yielded=True, + ) + computation_results: dict = {} for computation in requested_computations: - path = computation[0] - entity_plural, entity_id, variable_name, period = path.split('/') + path = computation[ + 0 + ] # format: entity_plural/entity_instance_id/openfisca_variable_name/period + entity_plural, entity_id, variable_name, period = path.split("/") variable = tax_benefit_system.get_variable(variable_name) result = simulation.calculate(variable_name, period) population = simulation.get_population(entity_plural) @@ -23,15 +27,39 @@ def calculate(tax_benefit_system, input_data): if variable.value_type == Enum: entity_result = result.decode()[entity_index].name elif variable.value_type == float: - entity_result = float(str(result[entity_index])) # To turn the float32 into a regular float without adding confusing extra decimals. There must be a better way. + entity_result = float( + str(result[entity_index]), + ) # To turn the float32 into a regular float without adding confusing extra decimals. There must be a better way. elif variable.value_type == str: entity_result = str(result[entity_index]) else: entity_result = result.tolist()[entity_index] - - dpath.util.new(computation_results, path, entity_result) - - dpath.merge(input_data, computation_results) + # Don't use dpath.util.new, because there is a problem with dpath>=2.0 + # when we have a key that is numeric, like the year. + # See https://github.com/dpath-maintainers/dpath-python/issues/160 + if computation_results == {}: + computation_results = { + entity_plural: {entity_id: {variable_name: {period: entity_result}}}, + } + elif entity_plural in computation_results: + if entity_id in computation_results[entity_plural]: + if variable_name in computation_results[entity_plural][entity_id]: + computation_results[entity_plural][entity_id][variable_name][ + period + ] = entity_result + else: + computation_results[entity_plural][entity_id][variable_name] = { + period: entity_result, + } + else: + computation_results[entity_plural][entity_id] = { + variable_name: {period: entity_result}, + } + else: + computation_results[entity_plural] = { + entity_id: {variable_name: {period: entity_result}}, + } + dpath.util.merge(input_data, computation_results) return input_data @@ -41,11 +69,16 @@ def trace(tax_benefit_system, input_data): simulation.trace = True requested_calculations = [] - requested_computations = dpath.util.search(input_data, '*/*/*/*', afilter = lambda t: t is None, yielded = True) + requested_computations = dpath.util.search( + input_data, + "*/*/*/*", + afilter=lambda t: t is None, + yielded=True, + ) for computation in requested_computations: path = computation[0] - entity_plural, entity_id, variable_name, period = path.split('/') - requested_calculations.append(f"{variable_name}<{str(period)}>") + entity_plural, entity_id, variable_name, period = path.split("/") + requested_calculations.append(f"{variable_name}<{period!s}>") simulation.calculate(variable_name, period) trace = simulation.tracer.get_serialized_flat_trace() @@ -53,5 +86,5 @@ def trace(tax_benefit_system, input_data): return { "trace": trace, "entitiesDescription": simulation.describe_entities(), - "requestedCalculations": requested_calculations - } + "requestedCalculations": requested_calculations, + } diff --git a/openfisca_web_api/loader/__init__.py b/openfisca_web_api/loader/__init__.py index b86aefad57..8d9318d9ae 100644 --- a/openfisca_web_api/loader/__init__.py +++ b/openfisca_web_api/loader/__init__.py @@ -1,24 +1,22 @@ -# -*- coding: utf-8 -*- - - -from openfisca_web_api.loader.parameters import build_parameters -from openfisca_web_api.loader.variables import build_variables from openfisca_web_api.loader.entities import build_entities +from openfisca_web_api.loader.parameters import build_parameters from openfisca_web_api.loader.spec import build_openAPI_specification +from openfisca_web_api.loader.variables import build_variables def build_data(tax_benefit_system): country_package_metadata = tax_benefit_system.get_package_metadata() parameters = build_parameters(tax_benefit_system, country_package_metadata) variables = build_variables(tax_benefit_system, country_package_metadata) + entities = build_entities(tax_benefit_system) data = { - 'tax_benefit_system': tax_benefit_system, - 'country_package_metadata': tax_benefit_system.get_package_metadata(), - 'openAPI_spec': None, - 'parameters': parameters, - 'variables': variables, - 'entities': build_entities(tax_benefit_system), - } - data['openAPI_spec'] = build_openAPI_specification(data) + "tax_benefit_system": tax_benefit_system, + "country_package_metadata": country_package_metadata, + "openAPI_spec": None, + "parameters": parameters, + "variables": variables, + "entities": entities, + } + data["openAPI_spec"] = build_openAPI_specification(data) return data diff --git a/openfisca_web_api/loader/entities.py b/openfisca_web_api/loader/entities.py index 2f3194882a..98ce4e6fb9 100644 --- a/openfisca_web_api/loader/entities.py +++ b/openfisca_web_api/loader/entities.py @@ -1,39 +1,28 @@ -# -*- coding: utf-8 -*- - - def build_entities(tax_benefit_system): - entities = { - entity.key: build_entity(entity) - for entity in tax_benefit_system.entities - } - return entities + return {entity.key: build_entity(entity) for entity in tax_benefit_system.entities} def build_entity(entity): formatted_doc = entity.doc.strip() formatted_entity = { - 'plural': entity.plural, - 'description': entity.label, - 'documentation': formatted_doc - } + "plural": entity.plural, + "description": entity.label, + "documentation": formatted_doc, + } if not entity.is_person: - formatted_entity['roles'] = { - role.key: build_role(role) - for role in entity.roles - } + formatted_entity["roles"] = { + role.key: build_role(role) for role in entity.roles + } return formatted_entity def build_role(role): - formatted_role = { - 'plural': role.plural, - 'description': role.doc - } + formatted_role = {"plural": role.plural, "description": role.doc} if role.max: - formatted_role['max'] = role.max + formatted_role["max"] = role.max if role.subroles: - formatted_role['max'] = len(role.subroles) + formatted_role["max"] = len(role.subroles) return formatted_role diff --git a/openfisca_web_api/loader/parameters.py b/openfisca_web_api/loader/parameters.py index 23a5f738b5..193f12915f 100644 --- a/openfisca_web_api/loader/parameters.py +++ b/openfisca_web_api/loader/parameters.py @@ -1,4 +1,5 @@ -# -*- coding: utf-8 -*- +import functools +import operator from openfisca_core.parameters import Parameter, ParameterNode, Scale @@ -12,43 +13,57 @@ def build_api_values_history(values_history): def get_value(date, values): - candidates = sorted([ - (start_date, value) - for start_date, value in values.items() - if start_date <= date # dates are lexicographically ordered and can be sorted - ], reverse = True) + candidates = sorted( + [ + (start_date, value) + for start_date, value in values.items() + if start_date + <= date # dates are lexicographically ordered and can be sorted + ], + reverse=True, + ) if candidates: return candidates[0][1] - else: - return None + return None def build_api_scale(scale, value_key_name): # preprocess brackets for a scale with 'rates' or 'amounts' - brackets = [{ - 'thresholds': build_api_values_history(bracket.threshold), - 'values': build_api_values_history(getattr(bracket, value_key_name)) - } for bracket in scale.brackets] - - dates = set(sum( - [list(bracket['thresholds'].keys()) - + list(bracket['values'].keys()) for bracket in brackets], - [])) # flatten the dates and remove duplicates + brackets = [ + { + "thresholds": build_api_values_history(bracket.threshold), + "values": build_api_values_history(getattr(bracket, value_key_name)), + } + for bracket in scale.brackets + ] + + dates = set( + functools.reduce( + operator.iadd, + [ + list(bracket["thresholds"].keys()) + list(bracket["values"].keys()) + for bracket in brackets + ], + [], + ), + ) # flatten the dates and remove duplicates # We iterate on all dates as we need to build the whole scale for each of them api_scale = {} for date in dates: for bracket in brackets: - threshold_value = get_value(date, bracket['thresholds']) + threshold_value = get_value(date, bracket["thresholds"]) if threshold_value is not None: - rate_or_amount_value = get_value(date, bracket['values']) + rate_or_amount_value = get_value(date, bracket["values"]) api_scale[date] = api_scale.get(date) or {} api_scale[date][threshold_value] = rate_or_amount_value # Handle stopped parameters: a parameter is stopped if its first bracket is stopped - latest_date_first_threshold = max(brackets[0]['thresholds'].keys()) - latest_value_first_threshold = brackets[0]['thresholds'][latest_date_first_threshold] + latest_date_first_threshold = max(brackets[0]["thresholds"].keys()) + latest_value_first_threshold = brackets[0]["thresholds"][ + latest_date_first_threshold + ] if latest_value_first_threshold is None: api_scale[latest_date_first_threshold] = None @@ -57,45 +72,51 @@ def build_api_scale(scale, value_key_name): def build_source_url(absolute_file_path, country_package_metadata): - relative_path = absolute_file_path.replace(country_package_metadata['location'], '') - return '{}/blob/{}{}'.format( - country_package_metadata['repository_url'], - country_package_metadata['version'], - relative_path - ) + relative_path = absolute_file_path.replace(country_package_metadata["location"], "") + return "{}/blob/{}{}".format( + country_package_metadata["repository_url"], + country_package_metadata["version"], + relative_path, + ) def build_api_parameter(parameter, country_package_metadata): api_parameter = { - 'description': getattr(parameter, "description", None), - 'id': parameter.name, - 'metadata': parameter.metadata - } + "description": getattr(parameter, "description", None), + "id": parameter.name, + "metadata": parameter.metadata, + } if parameter.file_path: - api_parameter['source'] = build_source_url(parameter.file_path, country_package_metadata) + api_parameter["source"] = build_source_url( + parameter.file_path, + country_package_metadata, + ) if isinstance(parameter, Parameter): if parameter.documentation: - api_parameter['documentation'] = parameter.documentation.strip() - api_parameter['values'] = build_api_values_history(parameter) + api_parameter["documentation"] = parameter.documentation.strip() + api_parameter["values"] = build_api_values_history(parameter) elif isinstance(parameter, Scale): - if 'rate' in parameter.brackets[0].children: - api_parameter['brackets'] = build_api_scale(parameter, 'rate') - elif 'amount' in parameter.brackets[0].children: - api_parameter['brackets'] = build_api_scale(parameter, 'amount') + if "rate" in parameter.brackets[0].children: + api_parameter["brackets"] = build_api_scale(parameter, "rate") + elif "amount" in parameter.brackets[0].children: + api_parameter["brackets"] = build_api_scale(parameter, "amount") elif isinstance(parameter, ParameterNode): if parameter.documentation: - api_parameter['documentation'] = parameter.documentation.strip() - api_parameter['subparams'] = { + api_parameter["documentation"] = parameter.documentation.strip() + api_parameter["subparams"] = { child_name: { - 'description': child.description, - } - for child_name, child in parameter.children.items() + "description": child.description, } + for child_name, child in parameter.children.items() + } return api_parameter def build_parameters(tax_benefit_system, country_package_metadata): return { - parameter.name.replace('.', '/'): build_api_parameter(parameter, country_package_metadata) + parameter.name.replace(".", "/"): build_api_parameter( + parameter, + country_package_metadata, + ) for parameter in tax_benefit_system.parameters.get_descendants() - } + } diff --git a/openfisca_web_api/loader/spec.py b/openfisca_web_api/loader/spec.py index fde2818c33..4a163bd91f 100644 --- a/openfisca_web_api/loader/spec.py +++ b/openfisca_web_api/loader/spec.py @@ -1,76 +1,119 @@ -# -*- coding: utf-8 -*- - import os -import yaml from copy import deepcopy -import dpath +import dpath.util +import yaml from openfisca_core.indexed_enums import Enum from openfisca_web_api import handlers - -OPEN_API_CONFIG_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir, 'openAPI.yml') +OPEN_API_CONFIG_FILE = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + os.path.pardir, + "openAPI.yml", +) def build_openAPI_specification(api_data): - tax_benefit_system = api_data['tax_benefit_system'] - file = open(OPEN_API_CONFIG_FILE, 'r') + tax_benefit_system = api_data["tax_benefit_system"] + file = open(OPEN_API_CONFIG_FILE) spec = yaml.safe_load(file) - country_package_name = api_data['country_package_metadata']['name'].title() - dpath.new(spec, 'info/title', spec['info']['title'].replace("{COUNTRY_PACKAGE_NAME}", country_package_name)) - dpath.new(spec, 'info/description', spec['info']['description'].replace("{COUNTRY_PACKAGE_NAME}", country_package_name)) - dpath.new(spec, 'info/version', api_data['country_package_metadata']['version']) + country_package_name = api_data["country_package_metadata"]["name"].title() + country_package_version = api_data["country_package_metadata"]["version"] + dpath.util.new( + spec, + "info/title", + spec["info"]["title"].replace("{COUNTRY_PACKAGE_NAME}", country_package_name), + ) + dpath.util.new( + spec, + "info/description", + spec["info"]["description"].replace( + "{COUNTRY_PACKAGE_NAME}", + country_package_name, + ), + ) + dpath.util.new( + spec, + "info/version", + spec["info"]["version"].replace( + "{COUNTRY_PACKAGE_VERSION}", + country_package_version, + ), + ) for entity in tax_benefit_system.entities: name = entity.key.title() - spec['definitions'][name] = get_entity_json_schema(entity, tax_benefit_system) + spec["components"]["schemas"][name] = get_entity_json_schema( + entity, + tax_benefit_system, + ) situation_schema = get_situation_json_schema(tax_benefit_system) - dpath.new(spec, 'definitions/SituationInput', situation_schema) - dpath.new(spec, 'definitions/SituationOutput', situation_schema.copy()) - dpath.new(spec, 'definitions/Trace/properties/entitiesDescription/properties', { - entity.plural: {'type': 'array', 'items': {"type": "string"}} - for entity in tax_benefit_system.entities - }) + dpath.util.new(spec, "components/schemas/SituationInput", situation_schema) + dpath.util.new(spec, "components/schemas/SituationOutput", situation_schema.copy()) + dpath.util.new( + spec, + "components/schemas/Trace/properties/entitiesDescription/properties", + { + entity.plural: {"type": "array", "items": {"type": "string"}} + for entity in tax_benefit_system.entities + }, + ) # Get example from the served tax benefist system - if tax_benefit_system.open_api_config.get('parameter_example'): - parameter_id = tax_benefit_system.open_api_config['parameter_example'] - parameter_path = parameter_id.replace('.', '/') - parameter_example = api_data['parameters'][parameter_path] + if tax_benefit_system.open_api_config.get("parameter_example"): + parameter_id = tax_benefit_system.open_api_config["parameter_example"] + parameter_path = parameter_id.replace(".", "/") + parameter_example = api_data["parameters"][parameter_path] else: - parameter_example = next(iter(api_data['parameters'].values())) - dpath.new(spec, 'definitions/Parameter/example', parameter_example) + parameter_example = next(iter(api_data["parameters"].values())) + dpath.util.new(spec, "components/schemas/Parameter/example", parameter_example) - if tax_benefit_system.open_api_config.get('variable_example'): - variable_example = api_data['variables'][tax_benefit_system.open_api_config['variable_example']] + if tax_benefit_system.open_api_config.get("variable_example"): + variable_example = api_data["variables"][ + tax_benefit_system.open_api_config["variable_example"] + ] else: - variable_example = next(iter(api_data['variables'].values())) - dpath.new(spec, 'definitions/Variable/example', variable_example) - - if tax_benefit_system.open_api_config.get('simulation_example'): - simulation_example = tax_benefit_system.open_api_config['simulation_example'] - dpath.new(spec, 'definitions/SituationInput/example', simulation_example) - dpath.new(spec, 'definitions/SituationOutput/example', handlers.calculate(tax_benefit_system, deepcopy(simulation_example))) # calculate has side-effects - dpath.new(spec, 'definitions/Trace/example', handlers.trace(tax_benefit_system, simulation_example)) + variable_example = next(iter(api_data["variables"].values())) + dpath.util.new(spec, "components/schemas/Variable/example", variable_example) + + if tax_benefit_system.open_api_config.get("simulation_example"): + simulation_example = tax_benefit_system.open_api_config["simulation_example"] + dpath.util.new( + spec, + "components/schemas/SituationInput/example", + simulation_example, + ) + dpath.util.new( + spec, + "components/schemas/SituationOutput/example", + handlers.calculate(tax_benefit_system, deepcopy(simulation_example)), + ) # calculate has side-effects + dpath.util.new( + spec, + "components/schemas/Trace/example", + handlers.trace(tax_benefit_system, simulation_example), + ) else: - message = "No simulation example has been defined for this tax and benefit system. If you are the maintainer of {}, you can define an example by following this documentation: https://openfisca.org/doc/openfisca-web-api/config-openapi.html".format(country_package_name) - dpath.new(spec, 'definitions/SituationInput/example', message) - dpath.new(spec, 'definitions/SituationOutput/example', message) - dpath.new(spec, 'definitions/Trace/example', message) + message = f"No simulation example has been defined for this tax and benefit system. If you are the maintainer of {country_package_name}, you can define an example by following this documentation: https://openfisca.org/doc/openfisca-web-api/config-openapi.html" + dpath.util.new(spec, "components/schemas/SituationInput/example", message) + dpath.util.new(spec, "components/schemas/SituationOutput/example", message) + dpath.util.new(spec, "components/schemas/Trace/example", message) return spec def get_variable_json_schema(variable): result = { - 'type': 'object', - 'additionalProperties': {'type': variable.json_type}, - } + "type": "object", + "additionalProperties": {"type": variable.json_type}, + } if variable.value_type == Enum: - result['additionalProperties']['enum'] = [item.name for item in list(variable.possible_values)] + result["additionalProperties"]["enum"] = [ + item.name for item in list(variable.possible_values) + ] return result @@ -78,46 +121,48 @@ def get_variable_json_schema(variable): def get_entity_json_schema(entity, tax_benefit_system): if entity.is_person: return { - 'type': 'object', - 'properties': { + "type": "object", + "properties": { variable_name: get_variable_json_schema(variable) - for variable_name, variable in tax_benefit_system.get_variables(entity).items() - }, - 'additionalProperties': False, - } - else: - properties = {} - properties.update({ - role.plural or role.key: { - 'type': 'array', - "items": { - "type": "string" - } - } + for variable_name, variable in tax_benefit_system.get_variables( + entity, + ).items() + }, + "additionalProperties": False, + } + properties = {} + properties.update( + { + role.plural or role.key: {"type": "array", "items": {"type": "string"}} for role in entity.roles - }) - properties.update({ + }, + ) + properties.update( + { variable_name: get_variable_json_schema(variable) - for variable_name, variable in tax_benefit_system.get_variables(entity).items() - }) - return { - 'type': 'object', - 'properties': properties, - 'additionalProperties': False, - } + for variable_name, variable in tax_benefit_system.get_variables( + entity, + ).items() + }, + ) + return { + "type": "object", + "properties": properties, + "additionalProperties": False, + } def get_situation_json_schema(tax_benefit_system): return { - 'type': 'object', - 'additionalProperties': False, - 'properties': { + "type": "object", + "additionalProperties": False, + "properties": { entity.plural: { - 'type': 'object', - 'additionalProperties': { - "$ref": "#/definitions/{}".format(entity.key.title()) - } - } - for entity in tax_benefit_system.entities + "type": "object", + "additionalProperties": { + "$ref": f"#/components/schemas/{entity.key.title()}", + }, } - } + for entity in tax_benefit_system.entities + }, + } diff --git a/openfisca_web_api/loader/tax_benefit_system.py b/openfisca_web_api/loader/tax_benefit_system.py index 856f760008..358f960501 100644 --- a/openfisca_web_api/loader/tax_benefit_system.py +++ b/openfisca_web_api/loader/tax_benefit_system.py @@ -1,8 +1,6 @@ -# -*- coding: utf-8 -*- - import importlib -import traceback import logging +import traceback from os import linesep log = logging.getLogger(__name__) @@ -12,14 +10,18 @@ def build_tax_benefit_system(country_package_name): try: country_package = importlib.import_module(country_package_name) except ImportError: - message = linesep.join([traceback.format_exc(), - 'Could not import module `{}`.'.format(country_package_name), - 'Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.', - 'See more at .', - linesep]) + message = linesep.join( + [ + traceback.format_exc(), + f"Could not import module `{country_package_name}`.", + "Are you sure it is installed in your environment? If so, look at the stack trace above to determine the origin of this error.", + "See more at .", + linesep, + ], + ) raise ValueError(message) try: return country_package.CountryTaxBenefitSystem() except NameError: # Gunicorn swallows NameErrors. Force printing the stack trace. - log.error(traceback.format_exc()) + log.exception(traceback.format_exc()) raise diff --git a/openfisca_web_api/loader/variables.py b/openfisca_web_api/loader/variables.py index d9390fb3a2..6730dc0811 100644 --- a/openfisca_web_api/loader/variables.py +++ b/openfisca_web_api/loader/variables.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import datetime import inspect import textwrap @@ -10,8 +8,8 @@ def get_next_day(date): parsed_date = date - next_day = parsed_date + datetime.timedelta(days = 1) - return next_day.isoformat().split('T')[0] + next_day = parsed_date + datetime.timedelta(days=1) + return next_day.isoformat().split("T")[0] def get_default_value(variable): @@ -25,84 +23,99 @@ def get_default_value(variable): return default_value -def build_source_url(country_package_metadata, source_file_path, start_line_number, source_code): - nb_lines = source_code.count('\n') - return '{}/blob/{}{}#L{}-L{}'.format( - country_package_metadata['repository_url'], - country_package_metadata['version'], +def build_source_url( + country_package_metadata, + source_file_path, + start_line_number, + source_code, +): + nb_lines = source_code.count("\n") + return "{}/blob/{}{}#L{}-L{}".format( + country_package_metadata["repository_url"], + country_package_metadata["version"], source_file_path, start_line_number, start_line_number + nb_lines - 1, - ) + ) -def build_formula(formula, country_package_metadata, source_file_path, tax_benefit_system): +def build_formula(formula, country_package_metadata, source_file_path): source_code, start_line_number = inspect.getsourcelines(formula) - source_code = textwrap.dedent(''.join(source_code)) + source_code = textwrap.dedent("".join(source_code)) api_formula = { - 'source': build_source_url( + "source": build_source_url( country_package_metadata, source_file_path, start_line_number, - source_code - ), - 'content': source_code, - } + source_code, + ), + "content": source_code, + } if formula.__doc__: - api_formula['documentation'] = textwrap.dedent(formula.__doc__) + api_formula["documentation"] = textwrap.dedent(formula.__doc__) return api_formula -def build_formulas(formulas, country_package_metadata, source_file_path, tax_benefit_system): +def build_formulas(formulas, country_package_metadata, source_file_path): return { - start_date: build_formula(formula, country_package_metadata, source_file_path, tax_benefit_system) + start_date: build_formula(formula, country_package_metadata, source_file_path) for start_date, formula in formulas.items() - } + } -def build_variable(variable, country_package_metadata, tax_benefit_system): - comments, source_file_path, source_code, start_line_number = variable.get_introspection_data(tax_benefit_system) +def build_variable(variable, country_package_metadata): + ( + source_file_path, + source_code, + start_line_number, + ) = variable.get_introspection_data() result = { - 'id': variable.name, - 'description': variable.label, - 'valueType': VALUE_TYPES[variable.value_type]['formatted_value_type'], - 'defaultValue': get_default_value(variable), - 'definitionPeriod': variable.definition_period.upper(), - 'entity': variable.entity.key, - } + "id": variable.name, + "description": variable.label, + "valueType": VALUE_TYPES[variable.value_type]["formatted_value_type"], + "defaultValue": get_default_value(variable), + "definitionPeriod": variable.definition_period.upper(), + "entity": variable.entity.key, + } if source_code: - result['source'] = build_source_url( + result["source"] = build_source_url( country_package_metadata, source_file_path, start_line_number, - source_code - ) + source_code, + ) if variable.documentation: - result['documentation'] = variable.documentation.strip() + result["documentation"] = variable.documentation.strip() if variable.reference: - result['references'] = variable.reference + result["references"] = variable.reference if len(variable.formulas) > 0: - result['formulas'] = build_formulas(variable.formulas, country_package_metadata, source_file_path, tax_benefit_system) + result["formulas"] = build_formulas( + variable.formulas, + country_package_metadata, + source_file_path, + ) if variable.end: - result['formulas'][get_next_day(variable.end)] = None + result["formulas"][get_next_day(variable.end)] = None if variable.value_type == Enum: - result['possibleValues'] = {item.name: item.value for item in list(variable.possible_values)} + result["possibleValues"] = { + item.name: item.value for item in list(variable.possible_values) + } return result def build_variables(tax_benefit_system, country_package_metadata): return { - name: build_variable(variable, country_package_metadata, tax_benefit_system) + name: build_variable(variable, country_package_metadata) for name, variable in tax_benefit_system.variables.items() - } + } diff --git a/openfisca_web_api/openAPI.yml b/openfisca_web_api/openAPI.yml index d0c52f9a14..ce935e5596 100644 --- a/openfisca_web_api/openAPI.yml +++ b/openfisca_web_api/openAPI.yml @@ -1,374 +1,434 @@ -swagger: "2.0" +openapi: "3.0.0" + info: title: "{COUNTRY_PACKAGE_NAME} Web API" description: "The OpenFisca Web API lets you get up-to-date information and formulas included in the {COUNTRY_PACKAGE_NAME} legislation." - version: null + version: "{COUNTRY_PACKAGE_VERSION}" termsOfService: "https://openfisca.org/doc/licence.html" contact: email: "contact@openfisca.org" license: name: "AGPL" url: "https://www.gnu.org/licenses/agpl-3.0" -host: null -schemes: null + tags: - name: "Parameters" description: "A parameter is a numeric property of the legislation that can evolve over time." externalDocs: description: "Parameters documentation" url: "https://openfisca.org/doc/key-concepts/parameters.html" + - name: "Variables" description: "A variable depends on a person, or an entity (e.g. zip code, salary, income tax)." externalDocs: description: "Variables documentation" url: "https://openfisca.org/doc/key-concepts/variables.html" + - name: "Entities" description: "An entity is a person of a group of individuals (such as a household)." externalDocs: description: "Entities documentation" url: "https://openfisca.org/doc/key-concepts/person,_entities,_role.html" + - name: "Calculations" + - name: "Documentation" + +components: + schemas: + Parameter: + type: "object" + properties: + values: + $ref: "#/components/schemas/Values" + brackets: + type: "object" + additionalProperties: + $ref: "#/components/schemas/Brackets" + subparams: + type: "object" + additionalProperties: + type: "object" + properties: + definition: + type: "string" + metadata: + type: "object" + description: + type: "string" + id: + type: "integer" + format: "string" + source: + type: "string" + + Parameters: + type: "object" + additionalProperties: + type: "object" + properties: + description: + type: "string" + href: + type: "string" + + Variable: + type: "object" + properties: + defaultValue: + type: "string" + definitionPeriod: + type: "string" + enum: + - "MONTH" + - "YEAR" + - "ETERNITY" + description: + type: "string" + entity: + type: "string" + formulas: + type: "object" + additionalProperties: + $ref: "#/components/schemas/Formula" + id: + type: "string" + reference: + type: "array" + items: + type: "string" + source: + type: "string" + valueType: + type: "string" + enum: + - "Int" + - "Float" + - "Boolean" + - "Date" + - "String" + + Variables: + type: "object" + additionalProperties: + type: "object" + properties: + description: + type: "string" + href: + type: "string" + + Formula: + type: "object" + properties: + content: + type: "string" + source: + type: "string" + + Brackets: + type: "object" + additionalProperties: + type: "number" + format: "float" + + Values: + description: "All keys are ISO dates. Values can be numbers, booleans, or arrays of a single type (number, boolean or string)." + type: "object" + additionalProperties: + $ref: "#/components/schemas/Value" + # propertyNames: # this keyword is part of JSON Schema but is not supported in OpenAPI v3.0.0 + # pattern: "^[12][0-9]{3}-[01][0-9]-[0-3][0-9]$" # all keys are ISO dates + + Value: + oneOf: + - type: "boolean" + - type: "number" + format: "float" + - type: "array" + items: + oneOf: + - type: "string" + - type: "number" + + Entities: + type: "object" + properties: + description: + type: "string" + documentation: + type: "string" + plural: + type: "string" + roles: + type: "object" + additionalProperties: + $ref: "#/components/schemas/Roles" + + Roles: + type: "object" + properties: + description: + type: "string" + max: + type: "integer" + plural: + type: "string" + + Trace: + type: "object" + properties: + requestedCalculations: + type: "array" + items: + type: "string" + entitiesDescription: + type: "object" + additionalProperties: false # Will be dynamically added by the Web API + trace: + type: "object" + additionalProperties: + type: "object" + properties: + value: + type: "array" + items: {} + dependencies: + type: "array" + items: + type: "string" + parameters: + type: "object" + additionalProperties: + type: "object" + + headers: + Country-Package: + description: "The name of the country package currently loaded in this API server" + schema: + type: "string" + + Country-Package-Version: + description: "The version of the country package currently loaded in this API server" + schema: + type: "string" + pattern: "^(0|[1-9][0-9]*)\\.(0|[1-9][0-9]*)\\.(0|[1-9][0-9]*)(?:-((?:0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$" # adapted from https://semver.org/#is-there-a-suggested-regular-expression-regex-to-check-a-semver-string + paths: /calculate: post: summary: "Run a simulation" tags: - - Calculations + - "Calculations" operationId: "calculate" - consumes: - - "application/json" - produces: - - "application/json" - parameters: - - in: "body" - name: "Situation" + requestBody: description: "Describe the situation (persons and entities). Add the variable you wish to calculate in the proper entity, with null as the value. Learn more in our official documentation: https://openfisca.org/doc/openfisca-web-api/input-output-data.html" required: true - schema: - $ref: "#/definitions/SituationInput" + content: + application/json: + schema: + $ref: "#/components/schemas/SituationInput" responses: 200: description: "The calculation result is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/SituationOutput" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/SituationOutput" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 404: description: "A variable mentioned in the input situation does not exist in the loaded tax and benefit system. Details are sent back in the response body" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 400: description: "The request is invalid. Details about the error are sent back in the response body" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /parameters: get: tags: - "Parameters" summary: "List all available parameters" operationId: "getParameters" - produces: - - "application/json" responses: 200: description: "The list of parameters is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Parameters" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Parameters" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /parameter/{parameterID}: get: tags: - "Parameters" summary: "Get information about a specific parameter" operationId: "getParameter" - produces: - - "application/json" parameters: - name: "parameterID" in: "path" description: "ID of parameter. IDs can be obtained by enumerating the /parameters endpoint" required: true - type: "string" + schema: + type: "string" responses: 200: description: "The requested parameter's information is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Parameter" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Parameter" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 404: description: "The requested parameter does not exist" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /variables: get: tags: - "Variables" summary: "List all available variables" operationId: "getVariables" - produces: - - "application/json" responses: 200: description: "The list of variables is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Variables" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Variables" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /variable/{variableID}: get: tags: - "Variables" summary: "Get information about a specific variable" operationId: "getVariable" - produces: - - "application/json" parameters: - name: "variableID" in: "path" description: "ID of a variable. IDs can be obtained by enumerating the /variables endpoint." required: true - type: "string" + schema: + type: "string" responses: 200: description: "The requested variable's information is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Variable" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Variable" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 404: description: "The requested variable does not exist" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /entities: get: tags: - "Entities" summary: "List all available Entities" - operationId: "getVariables" - produces: - - "application/json" + operationId: "getEntities" responses: 200: description: "The list of the entities as well as their information is sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Entities" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Entities" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /trace: post: summary: "Explore a simulation's steps in details." tags: - - Calculations + - "Calculations" operationId: "trace" - consumes: - - "application/json" - produces: - - "application/json" - parameters: - - in: "body" - name: "Situation" + requestBody: description: "Describe the situation (persons and entities). Add the variable you wish to calculate in the proper entity, with null as the value." required: true - schema: - $ref: "#/definitions/SituationInput" + content: + application/json: + schema: + $ref: "#/components/schemas/SituationInput" responses: 200: description: "The calculation details are sent back in the response body" + content: + application/json: + schema: + $ref: "#/components/schemas/Trace" headers: - $ref: "#/commons/Headers" - schema: - $ref: "#/definitions/Trace" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 404: description: "A variable mentioned in the input situation does not exist in the loaded tax and benefit system. Details are sent back in the response body" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" 400: description: "The request is invalid. Details about the error are sent back in the response body" headers: - $ref: "#/commons/Headers" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" + /spec: get: - summary: Provide the API documentation in an OpenAPI format + summary: "Provide the API documentation in an OpenAPI format" tags: - - Documentation - operationId: spec - produces: - - application/json + - "Documentation" + operationId: "spec" responses: 200: - description: The API documentation is sent back in the response body + description: "The API documentation is sent back in the response body" headers: - $ref: "#/commons/Headers" - -definitions: - Parameter: - type: "object" - properties: - values: - $ref: "#/definitions/Values" - brackets: - type: "object" - additionalProperties: - $ref: "#/definitions/Brackets" - subparams: - type: "object" - additionalProperties: - type: "object" - properties: - definition: - type: "string" - metadata: - type: "object" - description: - type: "string" - id: - type: "integer" - format: "string" - source: - type: "string" - example: null - - Parameters: - type: "object" - additionalProperties: - type: "object" - properties: - description: - type: "string" - href: - type: "string" - - Variable: - type: "object" - properties: - defaultValue: - type: "string" - definitionPeriod: - type: string - enum: - - MONTH - - YEAR - - ETERNITY - description: - type: "string" - entity: - type: "string" - formulas: - type: "object" - additionalProperties: - $ref: "#/definitions/Formula" - id: - type: "string" - reference: - type: "array" - items: - type: "string" - source: - type: "string" - valueType: - type: "string" - enum: - - Int - - Float - - Boolean - - Date - - String - example: null - - Variables: - type: "object" - additionalProperties: - type: "object" - properties: - description: - type: "string" - href: - type: "string" - - Formula: - type: "object" - properties: - content: - type: "string" - source: - type: "string" - - Brackets: - type: "object" - additionalProperties: - type: "number" - format: "float" - - Values: - description: All keys are ISO dates. Values can be numbers, booleans, or arrays of a single type (number, boolean or string). - type: "object" - additionalProperties: true -# propertyNames: # this keyword is part of JSON Schema but is not supported in OpenAPI Specification at the time of writing, see https://swagger.io/docs/specification/data-models/keywords/#unsupported -# pattern: "^[12][0-9]{3}-[01][0-9]-[0-3][0-9]$" # all keys are ISO dates - - Entities: - type: "object" - properties: - description: - type: "string" - documentation: - type: "string" - plural: - type: "string" - roles: - type: "object" - additionalProperties: - $ref: "#/definitions/Roles" - Roles: - type: "object" - properties: - description: - type: "string" - max: - type: "integer" - plural: - type: "string" - SituationInput: null - SituationOutput: null - - Trace: - type: object - properties: - requestedCalculations: - type: array - items: - type: string - entitiesDescription: - type: object - properties: null # Will be dynamically added by the Web API - trace: - type: object - additionalProperties: - type: object - properties: - value: - type: array - items: - type: any - dependencies: - type: array - items: - type: string - parameters: - type: object - additionalProperties: - type: object - - example: null - -commons: - Headers: - Country-Package: - description: "The name of the country package currently loaded in this API server" - type: "string" - Country-Package-Version: - description: "The version of the country package currently loaded in this API server" - type: "string" + Country-Package: + $ref: "#/components/headers/Country-Package" + Country-Package-Version: + $ref: "#/components/headers/Country-Package-Version" diff --git a/openfisca_web_api/scripts/serve.py b/openfisca_web_api/scripts/serve.py index 428cf2b965..6ba89f440a 100644 --- a/openfisca_web_api/scripts/serve.py +++ b/openfisca_web_api/scripts/serve.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- - -import sys import logging +import sys from openfisca_core.scripts import build_tax_benefit_system from openfisca_web_api.app import create_app from openfisca_web_api.errors import handle_import_error try: - from gunicorn.app.base import BaseApplication from gunicorn import config + from gunicorn.app.base import BaseApplication except ImportError as error: handle_import_error(error) @@ -18,10 +16,10 @@ Define the `openfisca serve` command line interface. """ -DEFAULT_PORT = '5000' -HOST = '127.0.0.1' -DEFAULT_WORKERS_NUMBER = '3' -DEFAULT_TIMEOUT = 120 +DEFAULT_PORT = "5000" +HOST = "127.0.0.1" +DEFAULT_WORKERS_NUMBER = "3" +DEFAULT_TIMEOUT = 1200 log = logging.getLogger(__name__) @@ -33,7 +31,7 @@ def read_user_configuration(default_configuration, command_line_parser): if args.configuration_file: file_configuration = {} - with open(args.configuration_file, "r") as file: + with open(args.configuration_file) as file: exec(file.read(), {}, file_configuration) # Configuration file overloads default configuration @@ -42,10 +40,13 @@ def read_user_configuration(default_configuration, command_line_parser): # Command line configuration overloads all configuration gunicorn_parser = config.Config().parser() configuration = update(configuration, vars(args)) - configuration = update(configuration, vars(gunicorn_parser.parse_args(unknown_args))) - if configuration['args']: + configuration = update( + configuration, + vars(gunicorn_parser.parse_args(unknown_args)), + ) + if configuration["args"]: command_line_parser.print_help() - log.error('Unexpected positional argument {}'.format(configuration['args'])) + log.error("Unexpected positional argument {}".format(configuration["args"])) sys.exit(1) return configuration @@ -56,42 +57,43 @@ def update(configuration, new_options): if value is not None: configuration[key] = value if key == "port": - configuration['bind'] = configuration['bind'][:-4] + str(configuration['port']) + configuration["bind"] = configuration["bind"][:-4] + str( + configuration["port"], + ) return configuration class OpenFiscaWebAPIApplication(BaseApplication): - - def __init__(self, options): + def __init__(self, options) -> None: self.options = options - super(OpenFiscaWebAPIApplication, self).__init__() + super().__init__() - def load_config(self): + def load_config(self) -> None: for key, value in self.options.items(): if key in self.cfg.settings: self.cfg.set(key.lower(), value) def load(self): tax_benefit_system = build_tax_benefit_system( - self.options.get('country_package'), - self.options.get('extensions'), - self.options.get('reforms') - ) + self.options.get("country_package"), + self.options.get("extensions"), + self.options.get("reforms"), + ) return create_app( tax_benefit_system, - self.options.get('tracker_url'), - self.options.get('tracker_idsite'), - self.options.get('tracker_token'), - self.options.get('welcome_message') - ) + self.options.get("tracker_url"), + self.options.get("tracker_idsite"), + self.options.get("tracker_token"), + self.options.get("welcome_message"), + ) -def main(parser): +def main(parser) -> None: configuration = { - 'port': DEFAULT_PORT, - 'bind': '{}:{}'.format(HOST, DEFAULT_PORT), - 'workers': DEFAULT_WORKERS_NUMBER, - 'timeout': DEFAULT_TIMEOUT, - } + "port": DEFAULT_PORT, + "bind": f"{HOST}:{DEFAULT_PORT}", + "workers": DEFAULT_WORKERS_NUMBER, + "timeout": DEFAULT_TIMEOUT, + } configuration = read_user_configuration(configuration, parser) OpenFiscaWebAPIApplication(configuration).run() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..ce4ba3779f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +target-version = ["py39", "py310", "py311", "py312"] diff --git a/setup.cfg b/setup.cfg index c6f0809aae..9b8ce699bb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,42 +1,98 @@ -; E128/133: We prefer hang-closing visual indents -; E251: We prefer `function(x = 1)` over `function(x=1)` -; E501: We do not enforce a maximum line length -; F403/405: We ignore * imports -; W503/504: We break lines before binary operators (Knuth's style) +# C011X: We (progressively) document the code base. +# D10X: We (progressively) check docstrings (see https://www.pydocstyle.org/en/2.1.1/error_codes.html#grouping). +# DARXXX: We (progressively) check docstrings (see https://github.com/terrencepreilly/darglint#error-codes). +# E203: We ignore a false positive in whitespace before ":" (see https://github.com/PyCQA/pycodestyle/issues/373). +# F403/405: We ignore * imports. +# R0401: We avoid cyclic imports —required for unit/doc tests. +# RST301: We use Google Python Style (see https://pypi.org/project/flake8-rst-docstrings/). +# W503/504: We break lines before binary operators (Knuth's style). [flake8] -hang-closing = true -extend-ignore = D -ignore = E128,E251,F403,F405,E501,W503,W504 -in-place = true -include-in-doctest = openfisca_core/commons openfisca_core/entities openfisca_core/indexed_enums openfisca_core/types -jobs = 0 -rst-directives = attribute, deprecated, seealso, versionadded, versionchanged -rst-roles = any, attr, class, data, exc, func, meth, obj -strictness = short +convention = google +docstring_style = google +extend-ignore = D +ignore = B019, E203, E501, F405, E701, E704, RST212, RST213, RST301, RST306, W503 +in-place = true +include-in-doctest = + openfisca_core/commons + openfisca_core/entities + openfisca_core/holders + openfisca_core/periods + openfisca_core/projectors +max-line-length = 88 +per-file-ignores = + */types.py:D101,D102,E301,E704,W504 + */test_*.py:D101,D102,D103 + */__init__.py:F401 + */__init__.pyi:E302,E704 +rst-directives = attribute, deprecated, seealso, versionadded, versionchanged +rst-roles = any, attr, class, exc, func, meth, mod, obj +strictness = short + +[pylint.MASTER] +load-plugins = pylint_per_file_ignores [pylint.message_control] -disable = all -jobs = 0 -score = no +disable = all +enable = C0115, C0116, R0401 +per-file-ignores = + types.py:C0115,C0116 + /tests/:C0116 +score = no -[tool:pytest] -addopts = --cov-report=term-missing:skip-covered --cov-fail-under=78.69 --doctest-modules --disable-pytest-warnings --showlocals -doctest_optionflags = ELLIPSIS IGNORE_EXCEPTION_DETAIL NUMBER NORMALIZE_WHITESPACE -python_files = **/*.py -testpaths = openfisca_core/commons openfisca_core/entities openfisca_core/indexed_enums openfisca_core/types tests +[isort] +case_sensitive = true +combine_as_imports = true +force_alphabetical_sort_within_sections = false +group_by_package = true +honor_noqa = true +include_trailing_comma = true +known_first_party = openfisca_core +known_openfisca = openfisca_country_template, openfisca_extension_template +known_typing = *collections.abc*, *typing*, *typing_extensions* +known_types = *types* +multi_line_output = 3 +profile = black +py_version = 39 +sections = FUTURE, TYPING, TYPES, STDLIB, THIRDPARTY, OPENFISCA, FIRSTPARTY, LOCALFOLDER -[mypy] -ignore_missing_imports = True +[coverage:paths] +source = . */site-packages -[mypy-openfisca_core.commons.tests.*] -ignore_errors = True +[coverage:run] +branch = true +source = openfisca_core, openfisca_web_api -[mypy-openfisca_core.entities.tests.*] -ignore_errors = True +[coverage:report] +fail_under = 75 +show_missing = true +skip_covered = true +skip_empty = true -[mypy-openfisca_core.indexed_enums.tests.*] -ignore_errors = True +[tool:pytest] +addopts = --disable-pytest-warnings --doctest-modules --showlocals +doctest_optionflags = ELLIPSIS IGNORE_EXCEPTION_DETAIL NUMBER NORMALIZE_WHITESPACE +python_files = **/*.py +testpaths = tests + +[mypy] +check_untyped_defs = false +disallow_any_decorated = false +disallow_any_explicit = false +disallow_any_expr = false +disallow_any_unimported = false +follow_imports = skip +ignore_missing_imports = true +implicit_reexport = false +install_types = true +mypy_path = stubs +non_interactive = true +plugins = numpy.typing.mypy_plugin +pretty = true +python_version = 3.9 +strict = false +warn_no_return = true +warn_unreachable = true -[mypy-openfisca_core.scripts.*] -ignore_errors = True +[mypy-openfisca_core.*.tests.*] +ignore_errors = True diff --git a/setup.py b/setup.py index 512cdf08a0..fcfe490269 100644 --- a/setup.py +++ b/setup.py @@ -1,85 +1,119 @@ -#! /usr/bin/env python +"""Package config file. -from setuptools import setup, find_packages +This file contains all package's metadata, including the current version and +its third-party dependencies. + +Note: + For integration testing, OpenFisca-Core relies on two other packages, + listed below. Because these packages rely at the same time on + OpenFisca-Core, adding them as official dependencies creates a resolution + loop that makes it hard to contribute. We've therefore decided to install + them via the task manager (`make install-test`):: + + openfisca-country-template = "*" + openfisca-extension-template = "*" + +""" + +from pathlib import Path + +from setuptools import find_packages, setup + +# Read the contents of our README file for PyPi +this_directory = Path(__file__).parent +long_description = (this_directory / "README.md").read_text() # Please make sure to cap all dependency versions, in order to avoid unwanted # functional and integration breaks caused by external code updates. - +# DO NOT add space between '>=' and version number as it break conda build. general_requirements = [ - 'dpath >= 1.5.0, < 2.0.0', - 'nptyping >= 1.4.3, < 2.0.0', - 'numexpr >= 2.7.0, <= 3.0', - 'numpy >= 1.11, < 1.21', - 'psutil >= 5.4.7, < 6.0.0', - 'pytest >= 4.4.1, < 6.0.0', # For openfisca test - 'PyYAML >= 3.10', - 'sortedcontainers == 2.2.2', - 'typing-extensions >= 3.0.0.0, < 4.0.0.0', - ] + "PyYAML >=6.0, <7.0", + "StrEnum >=0.4.8, <0.5.0", # 3.11.x backport + "dpath >=2.1.4, <3.0", + "numexpr >=2.8.4, <3.0", + "numpy >=1.24.2, <2.0", + "pendulum >=3.0.0, <4.0.0", + "psutil >=5.9.4, <6.0", + "pytest >=8.3.3, <9.0", + "sortedcontainers >=2.4.0, <3.0", + "typing_extensions >=4.5.0, <5.0", +] api_requirements = [ - 'flask == 1.1.2', - 'flask-cors == 3.0.10', - 'gunicorn >= 20.0.0, < 21.0.0', - 'werkzeug >= 1.0.0, < 2.0.0', - ] + "Flask >=2.2.3, <3.0", + "Flask-Cors >=3.0.10, <4.0", + "gunicorn >=21.0, <22.0", + "Werkzeug >=2.2.3, <3.0", +] dev_requirements = [ - 'autopep8 >= 1.4.0, < 1.6.0', - 'darglint == 1.8.0', - 'flake8 >= 3.9.0, < 4.0.0', - 'flake8-bugbear >= 19.3.0, < 20.0.0', - 'flake8-docstrings == 1.6.0', - 'flake8-print >= 3.1.0, < 4.0.0', - 'flake8-rst-docstrings < 1.0.0', - 'mypy == 0.910', - 'openfisca-country-template >= 3.10.0, < 4.0.0', - 'openfisca-extension-template >= 1.2.0rc0, < 2.0.0', - 'pylint == 2.10.2', - 'pytest-cov >= 2.6.1, < 3.0.0', - 'types-PyYAML == 5.4.10', - 'types-setuptools == 57.0.2', - ] + api_requirements + "black >=24.8.0, <25.0", + "coverage >=7.6.1, <8.0", + "darglint >=1.8.1, <2.0", + "flake8 >=7.1.1, <8.0.0", + "flake8-bugbear >=24.8.19, <25.0", + "flake8-docstrings >=1.7.0, <2.0", + "flake8-print >=5.0.0, <6.0", + "flake8-rst-docstrings >=0.3.0, <0.4.0", + "idna >=3.10, <4.0", + "isort >=5.13.2, <6.0", + "mypy >=1.11.2, <2.0", + "openapi-spec-validator >=0.7.1, <0.8.0", + "pylint >=3.3.1, <4.0", + "pylint-per-file-ignores >=1.3.2, <2.0", + "pyright >=1.1.382, <2.0", + "ruff >=0.6.7, <1.0", + "ruff-lsp >=0.0.57, <1.0", + "xdoctest >=1.2.0, <2.0", + *api_requirements, +] setup( - name = 'OpenFisca-Core', - version = '35.6.0', - author = 'OpenFisca Team', - author_email = 'contact@openfisca.org', - classifiers = [ - 'Development Status :: 5 - Production/Stable', - 'License :: OSI Approved :: GNU Affero General Public License v3', - 'Operating System :: POSIX', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Topic :: Scientific/Engineering :: Information Analysis', - ], - description = 'A versatile microsimulation free software', - keywords = 'benefit microsimulation social tax', - license = 'https://www.fsf.org/licensing/licenses/agpl-3.0.html', - url = 'https://github.com/openfisca/openfisca-core', - - data_files = [ + name="OpenFisca-Core", + version="42.0.4", + author="OpenFisca Team", + author_email="contact@openfisca.org", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: GNU Affero General Public License v3", + "Operating System :: POSIX", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Information Analysis", + ], + description="A versatile microsimulation free software", + keywords="benefit microsimulation social tax", + license="https://www.fsf.org/licensing/licenses/agpl-3.0.html", + license_files=("LICENSE",), + url="https://github.com/openfisca/openfisca-core", + long_description=long_description, + long_description_content_type="text/markdown", + data_files=[ ( - 'share/openfisca/openfisca-core', - ['CHANGELOG.md', 'LICENSE', 'README.md'], - ), + "share/openfisca/openfisca-core", + ["CHANGELOG.md", "README.md"], + ), + ], + entry_points={ + "console_scripts": [ + "openfisca=openfisca_core.scripts.openfisca_command:main", + "openfisca-run-test=openfisca_core.scripts.openfisca_command:main", + ], + }, + extras_require={ + "web-api": api_requirements, + "dev": dev_requirements, + "ci": [ + "build >=0.10.0, <0.11.0", + "coveralls >=4.0.1, <5.0", + "twine >=5.1.1, <6.0", + "wheel >=0.40.0, <0.41.0", ], - entry_points = { - 'console_scripts': [ - 'openfisca=openfisca_core.scripts.openfisca_command:main', - 'openfisca-run-test=openfisca_core.scripts.openfisca_command:main', - ], - }, - extras_require = { - 'web-api': api_requirements, - 'dev': dev_requirements, - 'tracker': [ - 'openfisca-tracker == 0.4.0', - ], - }, - include_package_data = True, # Will read MANIFEST.in - install_requires = general_requirements, - packages = find_packages(exclude=['tests*']), - ) + "tracker": ["OpenFisca-Tracker >=0.4.0, <0.5.0"], + }, + include_package_data=True, # Will read MANIFEST.in + install_requires=general_requirements, + packages=find_packages(exclude=["tests*"]), +) diff --git a/stubs/numexpr/__init__.pyi b/stubs/numexpr/__init__.pyi new file mode 100644 index 0000000000..f9ada73c3b --- /dev/null +++ b/stubs/numexpr/__init__.pyi @@ -0,0 +1,9 @@ +from numpy.typing import NDArray + +import numpy + +def evaluate( + __ex: str, + *__args: object, + **__kwargs: object, +) -> NDArray[numpy.bool_] | NDArray[numpy.int32] | NDArray[numpy.float32]: ... diff --git a/tests/core/parameter_validation/test_parameter_clone.py b/tests/core/parameter_validation/test_parameter_clone.py index a14630e9a0..6c77b4bb0b 100644 --- a/tests/core/parameter_validation/test_parameter_clone.py +++ b/tests/core/parameter_validation/test_parameter_clone.py @@ -6,21 +6,20 @@ year = 2016 -def test_clone(): - path = os.path.join(BASE_DIR, 'filesystem_hierarchy') - parameters = ParameterNode('', directory_path = path) - parameters_at_instant = parameters('2016-01-01') +def test_clone() -> None: + path = os.path.join(BASE_DIR, "filesystem_hierarchy") + parameters = ParameterNode("", directory_path=path) + parameters_at_instant = parameters("2016-01-01") assert parameters_at_instant.node1.param == 1.0 clone = parameters.clone() - clone_at_instant = clone('2016-01-01') + clone_at_instant = clone("2016-01-01") assert clone_at_instant.node1.param == 1.0 assert id(clone) != id(parameters) assert id(clone.node1) != id(parameters.node1) assert id(clone.node1.param) != id(parameters.node1.param) -def test_clone_parameter(tax_benefit_system): - +def test_clone_parameter(tax_benefit_system) -> None: param = tax_benefit_system.parameters.taxes.income_tax_rate clone = param.clone() @@ -31,16 +30,16 @@ def test_clone_parameter(tax_benefit_system): assert clone.values_list == param.values_list -def test_clone_parameter_node(tax_benefit_system): +def test_clone_parameter_node(tax_benefit_system) -> None: node = tax_benefit_system.parameters.taxes clone = node.clone() assert clone is not node assert clone.income_tax_rate is not node.income_tax_rate - assert clone.children['income_tax_rate'] is not node.children['income_tax_rate'] + assert clone.children["income_tax_rate"] is not node.children["income_tax_rate"] -def test_clone_scale(tax_benefit_system): +def test_clone_scale(tax_benefit_system) -> None: scale = tax_benefit_system.parameters.taxes.social_security_contribution clone = scale.clone() @@ -48,7 +47,7 @@ def test_clone_scale(tax_benefit_system): assert clone.brackets[0].rate is not scale.brackets[0].rate -def test_deep_edit(tax_benefit_system): +def test_deep_edit(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters clone = parameters.clone() diff --git a/tests/core/parameter_validation/test_parameter_validation.py b/tests/core/parameter_validation/test_parameter_validation.py index 62b2b0c132..d3419312d2 100644 --- a/tests/core/parameter_validation/test_parameter_validation.py +++ b/tests/core/parameter_validation/test_parameter_validation.py @@ -1,15 +1,19 @@ -# -*- coding: utf-8 -*- - import os + import pytest -from openfisca_core.parameters import load_parameter_file, ParameterNode, ParameterParsingError + +from openfisca_core.parameters import ( + ParameterNode, + ParameterParsingError, + load_parameter_file, +) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) year = 2016 -def check_fails_with_message(file_name, keywords): - path = os.path.join(BASE_DIR, file_name) + '.yaml' +def check_fails_with_message(file_name, keywords) -> None: + path = os.path.join(BASE_DIR, file_name) + ".yaml" try: load_parameter_file(path, file_name) except ParameterParsingError as e: @@ -19,42 +23,65 @@ def check_fails_with_message(file_name, keywords): raise -@pytest.mark.parametrize("test", [ - ('indentation', {'Invalid YAML', 'indentation.yaml', 'line 2', 'mapping values are not allowed'}), - ("wrong_date", {"Error parsing parameter file", "Properties must be valid YYYY-MM-DD instants"}), - ('wrong_scale', {'Unexpected property', 'scale[1]', 'treshold'}), - ('wrong_value', {'not one of the allowed types', 'wrong_value[2015-12-01]', '1A'}), - ('unexpected_key_in_parameter', {'Unexpected property', 'unexpected_key'}), - ('wrong_type_in_parameter', {'must be of type object'}), - ('wrong_type_in_value_history', {'must be of type object'}), - ('unexpected_key_in_value_history', {'must be valid YYYY-MM-DD instants'}), - ('unexpected_key_in_value_at_instant', {'Unexpected property', 'unexpected_key'}), - ('unexpected_key_in_scale', {'Unexpected property', 'unexpected_key'}), - ('wrong_type_in_scale', {'must be of type object'}), - ('wrong_type_in_brackets', {'must be of type array'}), - ('wrong_type_in_bracket', {'must be of type object'}), - ('missing_value', {'missing', 'value'}), - ('duplicate_key', {'duplicate'}), - ]) -def test_parsing_errors(test): +@pytest.mark.parametrize( + "test", + [ + ( + "indentation", + { + "Invalid YAML", + "indentation.yaml", + "line 2", + "mapping values are not allowed", + }, + ), + ( + "wrong_date", + { + "Error parsing parameter file", + "Properties must be valid YYYY-MM-DD instants", + }, + ), + ("wrong_scale", {"Unexpected property", "scale[1]", "treshold"}), + ( + "wrong_value", + {"not one of the allowed types", "wrong_value[2015-12-01]", "1A"}, + ), + ("unexpected_key_in_parameter", {"Unexpected property", "unexpected_key"}), + ("wrong_type_in_parameter", {"must be of type object"}), + ("wrong_type_in_value_history", {"must be of type object"}), + ("unexpected_key_in_value_history", {"must be valid YYYY-MM-DD instants"}), + ( + "unexpected_key_in_value_at_instant", + {"Unexpected property", "unexpected_key"}, + ), + ("unexpected_key_in_scale", {"Unexpected property", "unexpected_key"}), + ("wrong_type_in_scale", {"must be of type object"}), + ("wrong_type_in_brackets", {"must be of type array"}), + ("wrong_type_in_bracket", {"must be of type object"}), + ("missing_value", {"missing", "value"}), + ("duplicate_key", {"duplicate"}), + ], +) +def test_parsing_errors(test) -> None: with pytest.raises(ParameterParsingError): check_fails_with_message(*test) -def test_array_type(): - path = os.path.join(BASE_DIR, 'array_type.yaml') - load_parameter_file(path, 'array_type') +def test_array_type() -> None: + path = os.path.join(BASE_DIR, "array_type.yaml") + load_parameter_file(path, "array_type") -def test_filesystem_hierarchy(): - path = os.path.join(BASE_DIR, 'filesystem_hierarchy') - parameters = ParameterNode('', directory_path = path) - parameters_at_instant = parameters('2016-01-01') +def test_filesystem_hierarchy() -> None: + path = os.path.join(BASE_DIR, "filesystem_hierarchy") + parameters = ParameterNode("", directory_path=path) + parameters_at_instant = parameters("2016-01-01") assert parameters_at_instant.node1.param == 1.0 -def test_yaml_hierarchy(): - path = os.path.join(BASE_DIR, 'yaml_hierarchy') - parameters = ParameterNode('', directory_path = path) - parameters_at_instant = parameters('2016-01-01') +def test_yaml_hierarchy() -> None: + path = os.path.join(BASE_DIR, "yaml_hierarchy") + parameters = ParameterNode("", directory_path=path) + parameters_at_instant = parameters("2016-01-01") assert parameters_at_instant.node1.param == 1.0 diff --git a/tests/core/parameters_date_indexing/__init__.py b/tests/core/parameters_date_indexing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/parameters_date_indexing/full_rate_age.yaml b/tests/core/parameters_date_indexing/full_rate_age.yaml new file mode 100644 index 0000000000..fa9377fec5 --- /dev/null +++ b/tests/core/parameters_date_indexing/full_rate_age.yaml @@ -0,0 +1,121 @@ +description: Full rate age +full_rate_age_by_birthdate: + description: Full rate age by birthdate + before_1951_07_01: + description: Born before 01/07/1951 + year: + description: Year + values: + 1983-04-01: + value: 65.0 + month: + description: Month + values: + 1983-04-01: + value: 0.0 + after_1951_07_01: + description: Born after 01/07/1951 + year: + description: Year + values: + 2011-07-01: + value: 65.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2011-07-01: + value: 4.0 + 1983-04-01: + value: null + after_1952_01_01: + description: Born after 01/01/1952 + year: + description: Year + values: + 2011-07-01: + value: 65.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 9.0 + 2011-07-01: + value: 8.0 + 1983-04-01: + value: null + after_1953_01_01: + description: Born after 01/01/1953 + year: + description: Year + values: + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 2.0 + 2011-07-01: + value: 0.0 + 1983-04-01: + value: null + after_1954_01_01: + description: Born after 01/01/1954 + year: + description: Year + values: + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 7.0 + 2011-07-01: + value: 4.0 + 1983-04-01: + value: null + after_1955_01_01: + description: Born after 01/01/1955 + year: + description: Year + values: + 2012-01-01: + value: 67.0 + 2011-07-01: + value: 66.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2012-01-01: + value: 0.0 + 2011-07-01: + value: 8.0 + 1983-04-01: + value: null + after_1956_01_01: + description: Born after 01/01/1956 + year: + description: Year + values: + 2011-07-01: + value: 67.0 + 1983-04-01: + value: null + month: + description: Month + values: + 2011-07-01: + value: 0.0 + 1983-04-01: + value: null diff --git a/tests/core/parameters_date_indexing/full_rate_required_duration.yml b/tests/core/parameters_date_indexing/full_rate_required_duration.yml new file mode 100644 index 0000000000..af394ec568 --- /dev/null +++ b/tests/core/parameters_date_indexing/full_rate_required_duration.yml @@ -0,0 +1,162 @@ +description: Required contribution duration for full rate +contribution_quarters_required_by_birthdate: + description: Contribution quarters required by birthdate + before_1934_01_01: + description: before 1934 + values: + 1983-01-01: + value: 150.0 + after_1934_01_01: + description: '1934-01-01' + values: + 1994-01-01: + value: 151.0 + 1983-01-01: + value: null + after_1935_01_01: + description: '1935-01-01' + values: + 1994-01-01: + value: 152.0 + 1983-01-01: + value: null + after_1936_01_01: + description: '1936-01-01' + values: + 1994-01-01: + value: 153.0 + 1983-01-01: + value: null + after_1937_01_01: + description: '1937-01-01' + values: + 1994-01-01: + value: 154.0 + 1983-01-01: + value: null + after_1938_01_01: + description: '1938-01-01' + values: + 1994-01-01: + value: 155.0 + 1983-01-01: + value: null + after_1939_01_01: + description: '1939-01-01' + values: + 1994-01-01: + value: 156.0 + 1983-01-01: + value: null + after_1940_01_01: + description: '1940-01-01' + values: + 1994-01-01: + value: 157.0 + 1983-01-01: + value: null + after_1941_01_01: + description: '1941-01-01' + values: + 1994-01-01: + value: 158.0 + 1983-01-01: + value: null + after_1942_01_01: + description: '1942-01-01' + values: + 1994-01-01: + value: 159.0 + 1983-01-01: + value: null + after_1943_01_01: + description: '1943-01-01' + values: + 1994-01-01: + value: 160.0 + 1983-01-01: + value: null + after_1949_01_01: + description: '1949-01-01' + values: + 2009-01-01: + value: 161.0 + 1983-01-01: + value: null + after_1950_01_01: + description: '1950-01-01' + values: + 2009-01-01: + value: 162.0 + 1983-01-01: + value: null + after_1951_01_01: + description: '1951-01-01' + values: + 2009-01-01: + value: 163.0 + 1983-01-01: + value: null + after_1952_01_01: + description: '1952-01-01' + values: + 2009-01-01: + value: 164.0 + 1983-01-01: + value: null + after_1953_01_01: + description: '1953-01-01' + values: + 2012-01-01: + value: 165.0 + 1983-01-01: + value: null + after_1955_01_01: + description: '1955-01-01' + values: + 2013-01-01: + value: 166.0 + 1983-01-01: + value: null + after_1958_01_01: + description: '1958-01-01' + values: + 2015-01-01: + value: 167.0 + 1983-01-01: + value: null + after_1961_01_01: + description: '1961-01-01' + values: + 2015-01-01: + value: 168.0 + 1983-01-01: + value: null + after_1964_01_01: + description: '1964-01-01' + values: + 2015-01-01: + value: 169.0 + 1983-01-01: + value: null + after_1967_01_01: + description: '1967-01-01' + values: + 2015-01-01: + value: 170.0 + 1983-01-01: + value: null + after_1970_01_01: + description: '1970-01-01' + values: + 2015-01-01: + value: 171.0 + 1983-01-01: + value: null + after_1973_01_01: + description: '1973-01-01' + values: + 2015-01-01: + value: 172.0 + 1983-01-01: + value: null diff --git a/tests/core/parameters_date_indexing/test_date_indexing.py b/tests/core/parameters_date_indexing/test_date_indexing.py new file mode 100644 index 0000000000..cefec26648 --- /dev/null +++ b/tests/core/parameters_date_indexing/test_date_indexing.py @@ -0,0 +1,48 @@ +import os + +import numpy + +from openfisca_core.parameters import ParameterNode +from openfisca_core.tools import assert_near + +from openfisca_core.model_api import * # noqa + +LOCAL_DIR = os.path.dirname(os.path.abspath(__file__)) + +parameters = ParameterNode(directory_path=LOCAL_DIR) + + +def get_message(error): + return error.args[0] + + +def test_on_leaf() -> None: + parameter_at_instant = parameters.full_rate_required_duration("1995-01-01") + birthdate = numpy.array( + ["1930-01-01", "1935-01-01", "1940-01-01", "1945-01-01"], + dtype="datetime64[D]", + ) + assert_near( + parameter_at_instant.contribution_quarters_required_by_birthdate[birthdate], + [150, 152, 157, 160], + ) + + +def test_on_node() -> None: + birthdate = numpy.array( + ["1950-01-01", "1953-01-01", "1956-01-01", "1959-01-01"], + dtype="datetime64[D]", + ) + parameter_at_instant = parameters.full_rate_age("2012-03-01") + node = parameter_at_instant.full_rate_age_by_birthdate[birthdate] + assert_near(node.year, [65, 66, 67, 67]) + assert_near(node.month, [0, 2, 0, 0]) + + +# def test_inhomogenous(): +# birthdate = numpy.array(['1930-01-01', '1935-01-01', '1940-01-01', '1945-01-01'], dtype = 'datetime64[D]') +# parameter_at_instant = parameters..full_rate_age('2011-01-01') +# parameter_at_instant.full_rate_age_by_birthdate[birthdate] +# with pytest.raises(ValueError) as error: +# parameter_at_instant.full_rate_age_by_birthdate[birthdate] +# assert "Cannot use fancy indexing on parameter node '.full_rate_age.full_rate_age_by_birthdate'" in get_message(error.value) diff --git a/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml b/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml new file mode 100644 index 0000000000..9894ae64aa --- /dev/null +++ b/tests/core/parameters_fancy_indexing/coefficient_de_minoration.yaml @@ -0,0 +1,135 @@ +description: Coefficient de minoration ARRCO +coefficient_minoration_en_fonction_distance_age_annulation_decote_en_annee: + description: Coefficient de minoration à l'Arrco en fonction de la distance à l'âge d'annulation de la décote (en année) + '-10': + description: '-10' + values: + 1965-01-01: + value: 0.43 + 1957-05-15: + value: null + '-9': + description: '-9' + values: + 1965-01-01: + value: 0.5 + 1957-05-15: + value: null + '-8': + description: '-8' + values: + 1965-01-01: + value: 0.57 + 1957-05-15: + value: null + '-7': + description: '-7' + values: + 1965-01-01: + value: 0.64 + 1957-05-15: + value: null + '-6': + description: '-6' + values: + 1965-01-01: + value: 0.71 + 1957-05-15: + value: null + '-5': + description: '-5' + values: + 1965-01-01: + value: 0.78 + 1957-05-15: + value: 0.75 + '-4': + description: '-4' + values: + 1965-01-01: + value: 0.83 + 1957-05-15: + value: 0.8 + '-3': + description: '-3' + values: + 1965-01-01: + value: 0.88 + 1957-05-15: + value: 0.85 + '-2': + description: '-2' + values: + 1965-01-01: + value: 0.92 + 1957-05-15: + value: 0.9 + '-1': + description: '-1' + values: + 1965-01-01: + value: 0.96 + 1957-05-15: + value: 0.95 + '0': + description: '0' + values: + 1965-01-01: + value: 1.0 + 1957-05-15: + value: 1.05 + '1': + description: '1' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.1 + '2': + description: '2' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.15 + '3': + description: '3' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.2 + '4': + description: '4' + values: + 1965-01-01: + value: null + 1957-05-15: + value: 1.25 + metadata: + order: + - '-10' + - '-9' + - '-8' + - '-7' + - '-6' + - '-5' + - '-4' + - '-3' + - '-2' + - '-1' + - '0' + - '1' + - '2' + - '3' + - '4' +metadata: + order: + - coefficient_minoration_en_fonction_distance_age_annulation_decote_en_annee + reference: + 1965-01-01: Article 18 de l'annexe A de l'Accord national interprofessionnel de retraite complémentaire du 8 décembre 1961 + 1957-05-15: Accord du 15/05/1957 pour la création de l'UNIRS + description_en: Penalty for early retirement ARRCO +documentation: | + Note: Le coefficient d'abattement (ou de majoration avant 1965) constitue une multiplication des droits de pension à l'arrco par le coefficient en question. Par exemple, un individu partant en retraite à 60 ans en 1960 touchait 75% de sa pension. A partir de 1983, une double condition d'âge et de durée d'assurance est instaurée: un individu ayant validé une durée égale à la durée d'assurance cible(voir onglet Trim_tx_plein_RG) partira sans abbattement, même s'il n'a pas atteint l'âge d'annulation de la décôte dans le régime général (voir onglet Age_ann_dec_RG). + Note : le coefficient de minoration est linéaire en nombre de trimestres, e.g. il est de 0,43 à AAD - 10 ans, de 0,4475 à AAD - 9 ans et 3 trimestres, de 0,465 à AAD - 9 ans et 2 trimestres, etc. diff --git a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py index d34eb00773..b7e7cf4e45 100644 --- a/tests/core/parameters_fancy_indexing/test_fancy_indexing.py +++ b/tests/core/parameters_fancy_indexing/test_fancy_indexing.py @@ -1,148 +1,177 @@ -# -*- coding: utf-8 -*- - import os import re -import numpy as np +import numpy import pytest - +from openfisca_core.indexed_enums import Enum +from openfisca_core.parameters import Parameter, ParameterNode, ParameterNotFound from openfisca_core.tools import assert_near -from openfisca_core.parameters import ParameterNode, Parameter, ParameterNotFound -from openfisca_core.model_api import * # noqa LOCAL_DIR = os.path.dirname(os.path.abspath(__file__)) -parameters = ParameterNode(directory_path = LOCAL_DIR) +parameters = ParameterNode(directory_path=LOCAL_DIR) -P = parameters.rate('2015-01-01') +P = parameters.rate("2015-01-01") def get_message(error): return error.args[0] -def test_on_leaf(): - zone = np.asarray(['z1', 'z2', 'z2', 'z1']) +def test_on_leaf() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) assert_near(P.single.owner[zone], [100, 200, 200, 100]) -def test_on_node(): - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) +def test_on_node() -> None: + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) node = P.single[housing_occupancy_status] assert_near(node.z1, [100, 100, 300, 300]) - assert_near(node['z1'], [100, 100, 300, 300]) + assert_near(node["z1"], [100, 100, 300, 300]) -def test_double_fancy_indexing(): - zone = np.asarray(['z1', 'z2', 'z2', 'z1']) - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) +def test_double_fancy_indexing() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) assert_near(P.single[housing_occupancy_status][zone], [100, 200, 400, 300]) -def test_double_fancy_indexing_on_node(): - family_status = np.asarray(['single', 'couple', 'single', 'couple']) - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) +def test_double_fancy_indexing_on_node() -> None: + family_status = numpy.asarray(["single", "couple", "single", "couple"]) + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) node = P[family_status][housing_occupancy_status] assert_near(node.z1, [100, 500, 300, 700]) - assert_near(node['z1'], [100, 500, 300, 700]) + assert_near(node["z1"], [100, 500, 300, 700]) assert_near(node.z2, [200, 600, 400, 800]) - assert_near(node['z2'], [200, 600, 400, 800]) - - -def test_triple_fancy_indexing(): - family_status = np.asarray(['single', 'single', 'single', 'single', 'couple', 'couple', 'couple', 'couple']) - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant', 'owner', 'owner', 'tenant', 'tenant']) - zone = np.asarray(['z1', 'z2', 'z1', 'z2', 'z1', 'z2', 'z1', 'z2']) - assert_near(P[family_status][housing_occupancy_status][zone], [100, 200, 300, 400, 500, 600, 700, 800]) - - -def test_wrong_key(): - zone = np.asarray(['z1', 'z2', 'z2', 'toto']) + assert_near(node["z2"], [200, 600, 400, 800]) + + +def test_triple_fancy_indexing() -> None: + family_status = numpy.asarray( + [ + "single", + "single", + "single", + "single", + "couple", + "couple", + "couple", + "couple", + ], + ) + housing_occupancy_status = numpy.asarray( + ["owner", "owner", "tenant", "tenant", "owner", "owner", "tenant", "tenant"], + ) + zone = numpy.asarray(["z1", "z2", "z1", "z2", "z1", "z2", "z1", "z2"]) + assert_near( + P[family_status][housing_occupancy_status][zone], + [100, 200, 300, 400, 500, 600, 700, 800], + ) + + +def test_wrong_key() -> None: + zone = numpy.asarray(["z1", "z2", "z2", "toto"]) with pytest.raises(ParameterNotFound) as e: P.single.owner[zone] assert "'rate.single.owner.toto' was not found" in get_message(e.value) -def test_inhomogenous(): - parameters = ParameterNode(directory_path = LOCAL_DIR) - parameters.rate.couple.owner.add_child('toto', Parameter('toto', { - "values": { - "2015-01-01": { - "value": 1000 +def test_inhomogenous() -> None: + parameters = ParameterNode(directory_path=LOCAL_DIR) + parameters.rate.couple.owner.add_child( + "toto", + Parameter( + "toto", + { + "values": { + "2015-01-01": {"value": 1000}, }, - } - })) + }, + ), + ) - P = parameters.rate('2015-01-01') - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) + P = parameters.rate("2015-01-01") + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) with pytest.raises(ValueError) as error: P.couple[housing_occupancy_status] assert "'rate.couple.owner.toto' exists" in get_message(error.value) assert "'rate.couple.tenant.toto' doesn't" in get_message(error.value) -def test_inhomogenous_2(): - parameters = ParameterNode(directory_path = LOCAL_DIR) - parameters.rate.couple.tenant.add_child('toto', Parameter('toto', { - "values": { - "2015-01-01": { - "value": 1000 +def test_inhomogenous_2() -> None: + parameters = ParameterNode(directory_path=LOCAL_DIR) + parameters.rate.couple.tenant.add_child( + "toto", + Parameter( + "toto", + { + "values": { + "2015-01-01": {"value": 1000}, }, - } - })) + }, + ), + ) - P = parameters.rate('2015-01-01') - housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) + P = parameters.rate("2015-01-01") + housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) with pytest.raises(ValueError) as e: P.couple[housing_occupancy_status] assert "'rate.couple.tenant.toto' exists" in get_message(e.value) assert "'rate.couple.owner.toto' doesn't" in get_message(e.value) -def test_inhomogenous_3(): - parameters = ParameterNode(directory_path = LOCAL_DIR) - parameters.rate.couple.tenant.add_child('z4', ParameterNode('toto', data = { - 'amount': { - 'values': { - "2015-01-01": {'value': 550}, - "2016-01-01": {'value': 600} - } - } - })) +def test_inhomogenous_3() -> None: + parameters = ParameterNode(directory_path=LOCAL_DIR) + parameters.rate.couple.tenant.add_child( + "z4", + ParameterNode( + "toto", + data={ + "amount": { + "values": { + "2015-01-01": {"value": 550}, + "2016-01-01": {"value": 600}, + }, + }, + }, + ), + ) - P = parameters.rate('2015-01-01') - zone = np.asarray(['z1', 'z2', 'z2', 'z1']) + P = parameters.rate("2015-01-01") + zone = numpy.asarray(["z1", "z2", "z2", "z1"]) with pytest.raises(ValueError) as e: P.couple.tenant[zone] assert "'rate.couple.tenant.z4' is a node" in get_message(e.value) assert re.findall(r"'rate.couple.tenant.z(1|2|3)' is not", get_message(e.value)) -P_2 = parameters.local_tax('2015-01-01') +P_2 = parameters.local_tax("2015-01-01") -def test_with_properties_starting_by_number(): - city_code = np.asarray(['75012', '75007', '75015']) +def test_with_properties_starting_by_number() -> None: + city_code = numpy.asarray(["75012", "75007", "75015"]) assert_near(P_2[city_code], [100, 300, 200]) -P_3 = parameters.bareme('2015-01-01') +P_3 = parameters.bareme("2015-01-01") -def test_with_bareme(): - city_code = np.asarray(['75012', '75007', '75015']) +def test_with_bareme() -> None: + city_code = numpy.asarray(["75012", "75007", "75015"]) with pytest.raises(NotImplementedError) as e: P_3[city_code] - assert re.findall(r"'bareme.7501\d' is a 'MarginalRateTaxScale'", get_message(e.value)) + assert re.findall( + r"'bareme.7501\d' is a 'MarginalRateTaxScale'", + get_message(e.value), + ) assert "has not been implemented" in get_message(e.value) -def test_with_enum(): - +def test_with_enum() -> None: class TypesZone(Enum): z1 = "Zone 1" z2 = "Zone 2" - zone = np.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) + zone = numpy.asarray([TypesZone.z1, TypesZone.z2, TypesZone.z2, TypesZone.z1]) assert_near(P.single.owner[zone], [100, 200, 200, 100]) diff --git a/tests/core/tax_scales/test_abstract_rate_tax_scale.py b/tests/core/tax_scales/test_abstract_rate_tax_scale.py index 3d284a49e9..c966aa30f3 100644 --- a/tests/core/tax_scales/test_abstract_rate_tax_scale.py +++ b/tests/core/tax_scales/test_abstract_rate_tax_scale.py @@ -1,9 +1,9 @@ -from openfisca_core import taxscales - import pytest +from openfisca_core import taxscales + -def test_abstract_tax_scale(): +def test_abstract_tax_scale() -> None: with pytest.warns(DeprecationWarning): result = taxscales.AbstractRateTaxScale() - assert type(result) == taxscales.AbstractRateTaxScale + assert isinstance(result, taxscales.AbstractRateTaxScale) diff --git a/tests/core/tax_scales/test_abstract_tax_scale.py b/tests/core/tax_scales/test_abstract_tax_scale.py index f6834e7dc7..aad04d58ed 100644 --- a/tests/core/tax_scales/test_abstract_tax_scale.py +++ b/tests/core/tax_scales/test_abstract_tax_scale.py @@ -1,9 +1,9 @@ -from openfisca_core import taxscales - import pytest +from openfisca_core import taxscales + -def test_abstract_tax_scale(): +def test_abstract_tax_scale() -> None: with pytest.warns(DeprecationWarning): result = taxscales.AbstractTaxScale() - assert type(result) == taxscales.AbstractTaxScale + assert isinstance(result, taxscales.AbstractTaxScale) diff --git a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py index 83153024c7..6205d6de9b 100644 --- a/tests/core/tax_scales/test_linear_average_rate_tax_scale.py +++ b/tests/core/tax_scales/test_linear_average_rate_tax_scale.py @@ -1,12 +1,10 @@ import numpy - -from openfisca_core import taxscales -from openfisca_core import tools - import pytest +from openfisca_core import taxscales, tools + -def test_bracket_indices(): +def test_bracket_indices() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -18,31 +16,31 @@ def test_bracket_indices(): tools.assert_near(result, [0, 0, 0, 1, 1, 2]) -def test_bracket_indices_with_factor(): +def test_bracket_indices_with_factor() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - result = tax_scale.bracket_indices(tax_base, factor = 2.0) + result = tax_scale.bracket_indices(tax_base, factor=2.0) tools.assert_near(result, [0, 0, 0, 0, 1, 1]) -def test_bracket_indices_with_round_decimals(): +def test_bracket_indices_with_round_decimals() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - result = tax_scale.bracket_indices(tax_base, round_decimals = 0) + result = tax_scale.bracket_indices(tax_base, round_decimals=0) tools.assert_near(result, [0, 0, 1, 1, 2, 2]) -def test_bracket_indices_without_tax_base(): +def test_bracket_indices_without_tax_base() -> None: tax_base = numpy.array([]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -53,7 +51,7 @@ def test_bracket_indices_without_tax_base(): tax_scale.bracket_indices(tax_base) -def test_bracket_indices_without_brackets(): +def test_bracket_indices_without_brackets() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() @@ -61,7 +59,7 @@ def test_bracket_indices_without_brackets(): tax_scale.bracket_indices(tax_base) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) @@ -71,7 +69,7 @@ def test_to_dict(): assert result == {"0": 0.0, "100": 0.1} -def test_to_marginal(): +def test_to_marginal() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -81,9 +79,9 @@ def test_to_marginal(): result = tax_scale.to_marginal() assert result.thresholds == [0, 1, 2] - tools.assert_near(result.rates, [0.1, 0.3, 0.2], absolute_error_margin = 0) + tools.assert_near(result.rates, [0.1, 0.3, 0.2], absolute_error_margin=0) tools.assert_near( result.calc(tax_base), [0.1, 0.25, 0.4, 0.5], - absolute_error_margin = 0, - ) + absolute_error_margin=0, + ) diff --git a/tests/core/tax_scales/test_marginal_amount_tax_scale.py b/tests/core/tax_scales/test_marginal_amount_tax_scale.py index 7582d725b4..0a3275c901 100644 --- a/tests/core/tax_scales/test_marginal_amount_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_amount_tax_scale.py @@ -1,12 +1,8 @@ from numpy import array - -from openfisca_core import parameters -from openfisca_core import periods -from openfisca_core import taxscales -from openfisca_core import tools - from pytest import fixture +from openfisca_core import parameters, periods, taxscales, tools + @fixture def data(): @@ -16,13 +12,15 @@ def data(): "brackets": [ { "threshold": {"2017-10-01": {"value": 0.23}}, - "amount": {"2017-10-01": {"value": 6}, }, - } - ], - } + "amount": { + "2017-10-01": {"value": 6}, + }, + }, + ], + } -def test_calc(): +def test_calc() -> None: tax_base = array([1, 8, 10]) tax_scale = taxscales.MarginalAmountTaxScale() tax_scale.add_bracket(6, 0.23) @@ -34,7 +32,7 @@ def test_calc(): # TODO: move, as we're testing Scale, not MarginalAmountTaxScale -def test_dispatch_scale_type_on_creation(data): +def test_dispatch_scale_type_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) diff --git a/tests/core/tax_scales/test_marginal_rate_tax_scale.py b/tests/core/tax_scales/test_marginal_rate_tax_scale.py index 1688e7e3cc..7696e95fc4 100644 --- a/tests/core/tax_scales/test_marginal_rate_tax_scale.py +++ b/tests/core/tax_scales/test_marginal_rate_tax_scale.py @@ -1,12 +1,10 @@ import numpy - -from openfisca_core import taxscales -from openfisca_core import tools - import pytest +from openfisca_core import taxscales, tools -def test_bracket_indices(): + +def test_bracket_indices() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -18,31 +16,31 @@ def test_bracket_indices(): tools.assert_near(result, [0, 0, 0, 1, 1, 2]) -def test_bracket_indices_with_factor(): +def test_bracket_indices_with_factor() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - result = tax_scale.bracket_indices(tax_base, factor = 2.0) + result = tax_scale.bracket_indices(tax_base, factor=2.0) tools.assert_near(result, [0, 0, 0, 0, 1, 1]) -def test_bracket_indices_with_round_decimals(): +def test_bracket_indices_with_round_decimals() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(2, 0) tax_scale.add_bracket(4, 0) - result = tax_scale.bracket_indices(tax_base, round_decimals = 0) + result = tax_scale.bracket_indices(tax_base, round_decimals=0) tools.assert_near(result, [0, 0, 1, 1, 2, 2]) -def test_bracket_indices_without_tax_base(): +def test_bracket_indices_without_tax_base() -> None: tax_base = numpy.array([]) tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) @@ -53,7 +51,7 @@ def test_bracket_indices_without_tax_base(): tax_scale.bracket_indices(tax_base) -def test_bracket_indices_without_brackets(): +def test_bracket_indices_without_brackets() -> None: tax_base = numpy.array([0, 1, 2, 3, 4, 5]) tax_scale = taxscales.LinearAverageRateTaxScale() @@ -61,7 +59,7 @@ def test_bracket_indices_without_brackets(): tax_scale.bracket_indices(tax_base) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.LinearAverageRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) @@ -71,7 +69,7 @@ def test_to_dict(): assert result == {"0": 0.0, "100": 0.1} -def test_calc(): +def test_calc() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5, 3.0, 4.0]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -84,11 +82,11 @@ def test_calc(): tools.assert_near( result, [0, 0.05, 0.1, 0.2, 0.3, 0.3], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_calc_without_round(): +def test_calc_without_round() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -99,56 +97,56 @@ def test_calc_without_round(): tools.assert_near( result, [10, 10.02, 10.0002, 10.06, 10.0006, 10.05, 10.0005], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_calc_when_round_is_1(): +def test_calc_when_round_is_1() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) - result = tax_scale.calc(tax_base, round_base_decimals = 1) + result = tax_scale.calc(tax_base, round_base_decimals=1) tools.assert_near( result, [10, 10.0, 10.0, 10.1, 10.0, 10, 10.0], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_calc_when_round_is_2(): +def test_calc_when_round_is_2() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) - result = tax_scale.calc(tax_base, round_base_decimals = 2) + result = tax_scale.calc(tax_base, round_base_decimals=2) tools.assert_near( result, [10, 10.02, 10.0, 10.06, 10.00, 10.05, 10], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_calc_when_round_is_3(): +def test_calc_when_round_is_3() -> None: tax_base = numpy.array([200, 200.2, 200.002, 200.6, 200.006, 200.5, 200.005]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) tax_scale.add_bracket(100, 0.1) - result = tax_scale.calc(tax_base, round_base_decimals = 3) + result = tax_scale.calc(tax_base, round_base_decimals=3) tools.assert_near( result, [10, 10.02, 10.0, 10.06, 10.001, 10.05, 10], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) -def test_marginal_rates(): +def test_marginal_rates() -> None: tax_base = numpy.array([0, 10, 50, 125, 250]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -160,7 +158,7 @@ def test_marginal_rates(): tools.assert_near(result, [0, 0, 0, 0.1, 0.2]) -def test_inverse(): +def test_inverse() -> None: gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -173,7 +171,7 @@ def test_inverse(): tools.assert_near(result.calc(net_tax_base), gross_tax_base, 1e-15) -def test_scale_tax_scales(): +def test_scale_tax_scales() -> None: tax_base = numpy.array([1, 2, 3]) tax_base_scale = 12.345 scaled_tax_base = tax_base * tax_base_scale @@ -187,7 +185,7 @@ def test_scale_tax_scales(): tools.assert_near(result.thresholds, scaled_tax_base) -def test_inverse_scaled_marginal_tax_scales(): +def test_inverse_scaled_marginal_tax_scales() -> None: gross_tax_base = numpy.array([1, 2, 3, 4, 5, 6]) gross_tax_base_scale = 12.345 scaled_gross_tax_base = gross_tax_base * gross_tax_base_scale @@ -196,17 +194,16 @@ def test_inverse_scaled_marginal_tax_scales(): tax_scale.add_bracket(1, 0.1) tax_scale.add_bracket(3, 0.05) scaled_tax_scale = tax_scale.scale_tax_scales(gross_tax_base_scale) - scaled_net_tax_base = ( - + scaled_gross_tax_base - - scaled_tax_scale.calc(scaled_gross_tax_base) - ) + scaled_net_tax_base = +scaled_gross_tax_base - scaled_tax_scale.calc( + scaled_gross_tax_base, + ) result = scaled_tax_scale.inverse() tools.assert_near(result.calc(scaled_net_tax_base), scaled_gross_tax_base, 1e-13) -def test_to_average(): +def test_to_average() -> None: tax_base = numpy.array([1, 1.5, 2, 2.5]) tax_scale = taxscales.MarginalRateTaxScale() tax_scale.add_bracket(0, 0) @@ -221,5 +218,33 @@ def test_to_average(): tools.assert_near( result.calc(tax_base), [0, 0.0375, 0.1, 0.125], - absolute_error_margin = 1e-10, - ) + absolute_error_margin=1e-10, + ) + + +def test_rate_from_bracket_indice() -> None: + tax_base = numpy.array([0, 1_000, 1_500, 50_000]) + tax_scale = taxscales.MarginalRateTaxScale() + tax_scale.add_bracket(0, 0) + tax_scale.add_bracket(400, 0.1) + tax_scale.add_bracket(15_000, 0.4) + + bracket_indice = tax_scale.bracket_indices(tax_base) + result = tax_scale.rate_from_bracket_indice(bracket_indice) + + assert isinstance(result, numpy.ndarray) + assert (result == numpy.array([0.0, 0.1, 0.1, 0.4])).all() + + +def test_rate_from_tax_base() -> None: + tax_base = numpy.array([0, 3_000, 15_500, 500_000]) + tax_scale = taxscales.MarginalRateTaxScale() + tax_scale.add_bracket(0, 0) + tax_scale.add_bracket(400, 0.1) + tax_scale.add_bracket(15_000, 0.4) + tax_scale.add_bracket(200_000, 0.6) + + result = tax_scale.rate_from_tax_base(tax_base) + + assert isinstance(result, numpy.ndarray) + assert (result == numpy.array([0.0, 0.1, 0.4, 0.6])).all() diff --git a/tests/core/tax_scales/test_rate_tax_scale_like.py b/tests/core/tax_scales/test_rate_tax_scale_like.py new file mode 100644 index 0000000000..9f5bc61286 --- /dev/null +++ b/tests/core/tax_scales/test_rate_tax_scale_like.py @@ -0,0 +1,17 @@ +import numpy + +from openfisca_core import taxscales + + +def test_threshold_from_tax_base() -> None: + tax_base = numpy.array([0, 33_000, 500, 400_000]) + tax_scale = taxscales.LinearAverageRateTaxScale() + tax_scale.add_bracket(0, 0) + tax_scale.add_bracket(400, 0.1) + tax_scale.add_bracket(15_000, 0.4) + tax_scale.add_bracket(200_000, 0.6) + + result = tax_scale.threshold_from_tax_base(tax_base) + + assert isinstance(result, numpy.ndarray) + assert (result == numpy.array([0, 15_000, 400, 200_000])).all() diff --git a/tests/core/tax_scales/test_single_amount_tax_scale.py b/tests/core/tax_scales/test_single_amount_tax_scale.py index c5e6483a7d..2b384f6374 100644 --- a/tests/core/tax_scales/test_single_amount_tax_scale.py +++ b/tests/core/tax_scales/test_single_amount_tax_scale.py @@ -1,12 +1,8 @@ import numpy - -from openfisca_core import parameters -from openfisca_core import periods -from openfisca_core import taxscales -from openfisca_core import tools - from pytest import fixture +from openfisca_core import parameters, periods, taxscales, tools + @fixture def data(): @@ -16,17 +12,19 @@ def data(): "type": "single_amount", "threshold_unit": "currency-EUR", "rate_unit": "/1", - }, + }, "brackets": [ { "threshold": {"2017-10-01": {"value": 0.23}}, - "amount": {"2017-10-01": {"value": 6}, }, - } - ], - } + "amount": { + "2017-10-01": {"value": 6}, + }, + }, + ], + } -def test_calc(): +def test_calc() -> None: tax_base = numpy.array([1, 8, 10]) tax_scale = taxscales.SingleAmountTaxScale() tax_scale.add_bracket(6, 0.23) @@ -37,7 +35,7 @@ def test_calc(): tools.assert_near(result, [0, 0.23, 0.29]) -def test_to_dict(): +def test_to_dict() -> None: tax_scale = taxscales.SingleAmountTaxScale() tax_scale.add_bracket(6, 0.23) tax_scale.add_bracket(9, 0.29) @@ -48,7 +46,7 @@ def test_to_dict(): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_assign_thresholds_on_creation(data): +def test_assign_thresholds_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -59,7 +57,7 @@ def test_assign_thresholds_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_assign_amounts_on_creation(data): +def test_assign_amounts_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) scale_at_instant = scale.get_at_instant(first_jan) @@ -70,7 +68,7 @@ def test_assign_amounts_on_creation(data): # TODO: move, as we're testing Scale, not SingleAmountTaxScale -def test_dispatch_scale_type_on_creation(data): +def test_dispatch_scale_type_on_creation(data) -> None: scale = parameters.Scale("amount_scale", data, "") first_jan = periods.Instant((2017, 11, 1)) diff --git a/tests/core/tax_scales/test_tax_scales_commons.py b/tests/core/tax_scales/test_tax_scales_commons.py index d45bdd894a..544e5a07fe 100644 --- a/tests/core/tax_scales/test_tax_scales_commons.py +++ b/tests/core/tax_scales/test_tax_scales_commons.py @@ -1,32 +1,30 @@ -from openfisca_core import parameters -from openfisca_core import taxscales -from openfisca_core import tools - import pytest +from openfisca_core import parameters, taxscales, tools + @pytest.fixture def node(): return parameters.ParameterNode( "baremes", - data = { + data={ "health": { "brackets": [ {"rate": {"2015-01-01": 0.05}, "threshold": {"2015-01-01": 0}}, {"rate": {"2015-01-01": 0.10}, "threshold": {"2015-01-01": 2000}}, - ] - }, + ], + }, "retirement": { "brackets": [ {"rate": {"2015-01-01": 0.02}, "threshold": {"2015-01-01": 0}}, {"rate": {"2015-01-01": 0.04}, "threshold": {"2015-01-01": 3000}}, - ] - }, + ], }, - )(2015) + }, + )(2015) -def test_combine_tax_scales(node): +def test_combine_tax_scales(node) -> None: result = taxscales.combine_tax_scales(node) tools.assert_near(result.thresholds, [0, 2000, 3000]) diff --git a/tests/core/test_axes.py b/tests/core/test_axes.py index f106a82a5b..11590daf51 100644 --- a/tests/core/test_axes.py +++ b/tests/core/test_axes.py @@ -1,192 +1,338 @@ import pytest +from openfisca_core import errors from openfisca_core.simulations import SimulationBuilder from openfisca_core.tools import test_runner - # With periods -def test_add_axis_without_period(persons): +def test_add_axis_without_period(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.set_default_period('2018-11') - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000}) + simulation_builder.set_default_period("2018-11") + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000], + ) # With variables -def test_add_axis_on_a_non_existing_variable(persons): +def test_add_axis_on_a_non_existing_variable(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'ubi', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "ubi", "min": 0, "max": 3000, "period": "2018-11"}, + ) with pytest.raises(KeyError): simulation_builder.expand_axes() -def test_add_axis_on_an_existing_variable_with_input(persons): +def test_add_axis_on_an_existing_variable_with_input(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {'salary': {'2018-11': 1000}}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {"salary": {"2018-11": 1000}}}, + ) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) - assert simulation_builder.get_count('persons') == 3 - assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2'] + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000], + ) + assert simulation_builder.get_count("persons") == 3 + assert simulation_builder.get_ids("persons") == ["Alicia0", "Alicia1", "Alicia2"] # With entities -def test_add_axis_on_persons(persons): +def test_add_axis_on_persons(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) - assert simulation_builder.get_count('persons') == 3 - assert simulation_builder.get_ids('persons') == ['Alicia0', 'Alicia1', 'Alicia2'] + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000], + ) + assert simulation_builder.get_count("persons") == 3 + assert simulation_builder.get_ids("persons") == ["Alicia0", "Alicia1", "Alicia2"] -def test_add_two_axes(persons): +def test_add_two_axes(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 1000, 2000]) + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000], + ) + assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( + [0, 1000, 2000], + ) -def test_add_axis_with_group(persons): +def test_add_axis_with_group(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11', 'index': 1}) + simulation_builder.add_person_entity(persons, {"Alicia": {}, "Javier": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) + simulation_builder.add_parallel_axis( + { + "count": 2, + "name": "salary", + "min": 0, + "max": 3000, + "period": "2018-11", + "index": 1, + }, + ) simulation_builder.expand_axes() - assert simulation_builder.get_count('persons') == 4 - assert simulation_builder.get_ids('persons') == ['Alicia0', 'Javier1', 'Alicia2', 'Javier3'] - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 0, 3000, 3000]) - - -def test_add_axis_with_group_int_period(persons): + assert simulation_builder.get_count("persons") == 4 + assert simulation_builder.get_ids("persons") == [ + "Alicia0", + "Javier1", + "Alicia2", + "Javier3", + ] + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 0, 3000, 3000], + ) + + +def test_add_axis_with_group_int_period(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018}) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': 2018, 'index': 1}) + simulation_builder.add_person_entity(persons, {"Alicia": {}, "Javier": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": 2018}, + ) + simulation_builder.add_parallel_axis( + { + "count": 2, + "name": "salary", + "min": 0, + "max": 3000, + "period": 2018, + "index": 1, + }, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018') == pytest.approx([0, 0, 3000, 3000]) + assert simulation_builder.get_input("salary", "2018") == pytest.approx( + [0, 0, 3000, 3000], + ) -def test_add_axis_on_households(persons, households): +def test_add_axis_on_households(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia', 'Javier']}, - 'houseb': {'parents': ['Tom']}, - }) - simulation_builder.register_variable('rent', households) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia", "Javier"]}, + "houseb": {"parents": ["Tom"]}, + }, + ) + simulation_builder.register_variable("rent", households) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_count('households') == 4 - assert simulation_builder.get_ids('households') == ['housea0', 'houseb1', 'housea2', 'houseb3'] - assert simulation_builder.get_input('rent', '2018-11') == pytest.approx([0, 0, 3000, 0]) - - -def test_axis_on_group_expands_persons(persons, households): + assert simulation_builder.get_count("households") == 4 + assert simulation_builder.get_ids("households") == [ + "housea0", + "houseb1", + "housea2", + "houseb3", + ] + assert simulation_builder.get_input("rent", "2018-11") == pytest.approx( + [0, 0, 3000, 0], + ) + + +def test_axis_on_group_expands_persons(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia', 'Javier']}, - 'houseb': {'parents': ['Tom']}, - }) - simulation_builder.register_variable('rent', households) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia", "Javier"]}, + "houseb": {"parents": ["Tom"]}, + }, + ) + simulation_builder.register_variable("rent", households) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_count('persons') == 6 + assert simulation_builder.get_count("persons") == 6 -def test_add_axis_distributes_roles(persons, households): +def test_add_axis_distributes_roles(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia']}, - 'houseb': {'parents': ['Tom'], 'children': ['Javier']}, - }) - simulation_builder.register_variable('rent', households) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia"]}, + "houseb": {"parents": ["Tom"], "children": ["Javier"]}, + }, + ) + simulation_builder.register_variable("rent", households) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert [role.key for role in simulation_builder.get_roles('households')] == ['parent', 'child', 'parent', 'parent', 'child', 'parent'] + assert [role.key for role in simulation_builder.get_roles("households")] == [ + "parent", + "child", + "parent", + "parent", + "child", + "parent", + ] -def test_add_axis_on_persons_distributes_roles(persons, households): +def test_add_axis_on_persons_distributes_roles(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia']}, - 'houseb': {'parents': ['Tom'], 'children': ['Javier']}, - }) - simulation_builder.register_variable('salary', persons) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia"]}, + "houseb": {"parents": ["Tom"], "children": ["Javier"]}, + }, + ) + simulation_builder.register_variable("salary", persons) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert [role.key for role in simulation_builder.get_roles('households')] == ['parent', 'child', 'parent', 'parent', 'child', 'parent'] + assert [role.key for role in simulation_builder.get_roles("households")] == [ + "parent", + "child", + "parent", + "parent", + "child", + "parent", + ] -def test_add_axis_distributes_memberships(persons, households): +def test_add_axis_distributes_memberships(persons, households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}, 'Javier': {}, 'Tom': {}}) - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Tom'], households, { - 'housea': {'parents': ['Alicia']}, - 'houseb': {'parents': ['Tom'], 'children': ['Javier']}, - }) - simulation_builder.register_variable('rent', households) - simulation_builder.add_parallel_axis({'count': 2, 'name': 'rent', 'min': 0, 'max': 3000, 'period': '2018-11'}) + simulation_builder.add_person_entity( + persons, + {"Alicia": {}, "Javier": {}, "Tom": {}}, + ) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Tom"], + households, + { + "housea": {"parents": ["Alicia"]}, + "houseb": {"parents": ["Tom"], "children": ["Javier"]}, + }, + ) + simulation_builder.register_variable("rent", households) + simulation_builder.add_parallel_axis( + {"count": 2, "name": "rent", "min": 0, "max": 3000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_memberships('households') == [0, 1, 1, 2, 3, 3] + assert simulation_builder.get_memberships("households") == [0, 1, 1, 2, 3, 3] -def test_add_perpendicular_axes(persons): +def test_add_perpendicular_axes(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, {'Alicia': {}}) - simulation_builder.register_variable('salary', persons) - simulation_builder.register_variable('pension', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) - simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) + simulation_builder.add_person_entity(persons, {"Alicia": {}}) + simulation_builder.register_variable("salary", persons) + simulation_builder.register_variable("pension", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) + simulation_builder.add_perpendicular_axis( + {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000, 0, 1500, 3000], + ) + assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( + [0, 0, 0, 2000, 2000, 2000], + ) -def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons): +def test_add_perpendicular_axis_on_an_existing_variable_with_input(persons) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_person_entity(persons, { - 'Alicia': { - 'salary': {'2018-11': 1000}, - 'pension': {'2018-11': 1000}, + simulation_builder.add_person_entity( + persons, + { + "Alicia": { + "salary": {"2018-11": 1000}, + "pension": {"2018-11": 1000}, }, - },) - simulation_builder.register_variable('salary', persons) - simulation_builder.register_variable('pension', persons) - simulation_builder.add_parallel_axis({'count': 3, 'name': 'salary', 'min': 0, 'max': 3000, 'period': '2018-11'}) - simulation_builder.add_perpendicular_axis({'count': 2, 'name': 'pension', 'min': 0, 'max': 2000, 'period': '2018-11'}) + }, + ) + simulation_builder.register_variable("salary", persons) + simulation_builder.register_variable("pension", persons) + simulation_builder.add_parallel_axis( + {"count": 3, "name": "salary", "min": 0, "max": 3000, "period": "2018-11"}, + ) + simulation_builder.add_perpendicular_axis( + {"count": 2, "name": "pension", "min": 0, "max": 2000, "period": "2018-11"}, + ) simulation_builder.expand_axes() - assert simulation_builder.get_input('salary', '2018-11') == pytest.approx([0, 1500, 3000, 0, 1500, 3000]) - assert simulation_builder.get_input('pension', '2018-11') == pytest.approx([0, 0, 0, 2000, 2000, 2000]) + assert simulation_builder.get_input("salary", "2018-11") == pytest.approx( + [0, 1500, 3000, 0, 1500, 3000], + ) + assert simulation_builder.get_input("pension", "2018-11") == pytest.approx( + [0, 0, 0, 2000, 2000, 2000], + ) -# Integration test +# Integration tests -def test_simulation_with_axes(tax_benefit_system): +def test_simulation_with_axes(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {salary: {2018-11: 0}} @@ -207,5 +353,31 @@ def test_simulation_with_axes(tax_benefit_system): """ data = test_runner.yaml.safe_load(input_yaml) simulation = SimulationBuilder().build_from_dict(tax_benefit_system, data) - assert simulation.get_array('salary', '2018-11') == pytest.approx([0, 0, 0, 0, 0, 0]) - assert simulation.get_array('rent', '2018-11') == pytest.approx([0, 0, 3000, 0]) + assert simulation.get_array("salary", "2018-11") == pytest.approx( + [0, 0, 0, 0, 0, 0], + ) + assert simulation.get_array("rent", "2018-11") == pytest.approx([0, 0, 3000, 0]) + + +# Test for missing group entities with build_from_entities() + + +def test_simulation_with_axes_missing_entities(tax_benefit_system) -> None: + input_yaml = """ + persons: + Alicia: {salary: {2018-11: 0}} + Javier: {} + Tom: {} + axes: + - + - count: 2 + name: rent + min: 0 + max: 3000 + period: 2018-11 + """ + data = test_runner.yaml.safe_load(input_yaml) + with pytest.raises(errors.SituationParsingError) as error: + SimulationBuilder().build_from_dict(tax_benefit_system, data) + assert "In order to expand over axes" in error.value() + assert "all group entities and roles must be fully specified" in error.value() diff --git a/tests/core/test_calculate_output.py b/tests/core/test_calculate_output.py index 6a11a27d84..54d868ba92 100644 --- a/tests/core/test_calculate_output.py +++ b/tests/core/test_calculate_output.py @@ -2,57 +2,67 @@ from openfisca_country_template import entities, situation_examples -from openfisca_core import periods, simulations, tools +from openfisca_core import simulations, tools +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable class simple_variable(Variable): entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH value_type = int class variable_with_calculate_output_add(Variable): entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH value_type = int calculate_output = simulations.calculate_output_add class variable_with_calculate_output_divide(Variable): entity = entities.Person - definition_period = periods.YEAR + definition_period = DateUnit.YEAR value_type = int calculate_output = simulations.calculate_output_divide -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( simple_variable, variable_with_calculate_output_add, - variable_with_calculate_output_divide - ) + variable_with_calculate_output_divide, + ) @pytest.fixture def simulation(tax_benefit_system): - return SimulationBuilder().build_from_entities(tax_benefit_system, situation_examples.single) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.single, + ) -def test_calculate_output_default(simulation): +def test_calculate_output_default(simulation) -> None: with pytest.raises(ValueError): - simulation.calculate_output('simple_variable', 2017) + simulation.calculate_output("simple_variable", 2017) -def test_calculate_output_add(simulation): - simulation.set_input('variable_with_calculate_output_add', '2017-01', [10]) - simulation.set_input('variable_with_calculate_output_add', '2017-05', [20]) - simulation.set_input('variable_with_calculate_output_add', '2017-12', [70]) - tools.assert_near(simulation.calculate_output('variable_with_calculate_output_add', 2017), 100) +def test_calculate_output_add(simulation) -> None: + simulation.set_input("variable_with_calculate_output_add", "2017-01", [10]) + simulation.set_input("variable_with_calculate_output_add", "2017-05", [20]) + simulation.set_input("variable_with_calculate_output_add", "2017-12", [70]) + tools.assert_near( + simulation.calculate_output("variable_with_calculate_output_add", 2017), + 100, + ) -def test_calculate_output_divide(simulation): - simulation.set_input('variable_with_calculate_output_divide', 2017, [12000]) - tools.assert_near(simulation.calculate_output('variable_with_calculate_output_divide', '2017-06'), 1000) +def test_calculate_output_divide(simulation) -> None: + simulation.set_input("variable_with_calculate_output_divide", 2017, [12000]) + tools.assert_near( + simulation.calculate_output("variable_with_calculate_output_divide", "2017-06"), + 1000, + ) diff --git a/tests/core/test_countries.py b/tests/core/test_countries.py index aeb4d762c7..d206a8cb35 100644 --- a/tests/core/test_countries.py +++ b/tests/core/test_countries.py @@ -2,55 +2,56 @@ from openfisca_core import periods, populations, tools from openfisca_core.errors import VariableNameConflictError, VariableNotFoundError +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable PERIOD = periods.period("2016-01") -@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect = True) -def test_input_variable(simulation): +@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True) +def test_input_variable(simulation) -> None: result = simulation.calculate("salary", PERIOD) - tools.assert_near(result, [2000], absolute_error_margin = 0.01) + tools.assert_near(result, [2000], absolute_error_margin=0.01) -@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect = True) -def test_basic_calculation(simulation): +@pytest.mark.parametrize("simulation", [({"salary": 2000}, PERIOD)], indirect=True) +def test_basic_calculation(simulation) -> None: result = simulation.calculate("income_tax", PERIOD) - tools.assert_near(result, [300], absolute_error_margin = 0.01) + tools.assert_near(result, [300], absolute_error_margin=0.01) -@pytest.mark.parametrize("simulation", [({"salary": 24000}, PERIOD)], indirect = True) -def test_calculate_add(simulation): +@pytest.mark.parametrize("simulation", [({"salary": 24000}, PERIOD)], indirect=True) +def test_calculate_add(simulation) -> None: result = simulation.calculate_add("income_tax", PERIOD) - tools.assert_near(result, [3600], absolute_error_margin = 0.01) + tools.assert_near(result, [3600], absolute_error_margin=0.01) @pytest.mark.parametrize( "simulation", [({"accommodation_size": 100, "housing_occupancy_status": "tenant"}, PERIOD)], - indirect = True, - ) -def test_calculate_divide(simulation): + indirect=True, +) +def test_calculate_divide(simulation) -> None: result = simulation.calculate_divide("housing_tax", PERIOD) - tools.assert_near(result, [1000 / 12.], absolute_error_margin = 0.01) + tools.assert_near(result, [1000 / 12.0], absolute_error_margin=0.01) -@pytest.mark.parametrize("simulation", [({"salary": 20000}, PERIOD)], indirect = True) -def test_bareme(simulation): +@pytest.mark.parametrize("simulation", [({"salary": 20000}, PERIOD)], indirect=True) +def test_bareme(simulation) -> None: result = simulation.calculate("social_security_contribution", PERIOD) expected = [0.02 * 6000 + 0.06 * 6400 + 0.12 * 7600] - tools.assert_near(result, expected, absolute_error_margin = 0.01) + tools.assert_near(result, expected, absolute_error_margin=0.01) -@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect = True) -def test_non_existing_variable(simulation): +@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) +def test_non_existing_variable(simulation) -> None: with pytest.raises(VariableNotFoundError): simulation.calculate("non_existent_variable", PERIOD) -@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect = True) -def test_calculate_variable_with_wrong_definition_period(simulation): +@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) +def test_calculate_variable_with_wrong_definition_period(simulation) -> None: year = str(PERIOD.this_year) with pytest.raises(ValueError) as error: @@ -60,30 +61,28 @@ def test_calculate_variable_with_wrong_definition_period(simulation): expected_words = ["period", year, "month", "basic_income", "ADD"] for word in expected_words: - assert word in error_message, f"Expected '{word}' in error message '{error_message}'" + assert ( + word in error_message + ), f"Expected '{word}' in error message '{error_message}'" -@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect = True) -def test_divide_option_on_month_defined_variable(simulation): - with pytest.raises(ValueError): - simulation.person("disposable_income", PERIOD, options = [populations.DIVIDE]) - - -@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect = True) -def test_divide_option_with_complex_period(simulation): +@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True) +def test_divide_option_with_complex_period(simulation) -> None: quarter = PERIOD.last_3_months with pytest.raises(ValueError) as error: - simulation.household("housing_tax", quarter, options = [populations.DIVIDE]) + simulation.household("housing_tax", quarter, options=[populations.DIVIDE]) error_message = str(error.value) - expected_words = ["DIVIDE", "one-year", "one-month", "period"] + expected_words = ["Can't", "calculate", "month", "year"] for word in expected_words: - assert word in error_message, f"Expected '{word}' in error message '{error_message}'" + assert ( + word in error_message + ), f"Expected '{word}' in error message '{error_message}'" -def test_input_with_wrong_period(tax_benefit_system): +def test_input_with_wrong_period(tax_benefit_system) -> None: year = str(PERIOD.this_year) variables = {"basic_income": {year: 12000}} simulation_builder = SimulationBuilder() @@ -93,7 +92,7 @@ def test_input_with_wrong_period(tax_benefit_system): simulation_builder.build_from_variables(tax_benefit_system, variables) -def test_variable_with_reference(make_simulation, isolated_tax_benefit_system): +def test_variable_with_reference(make_simulation, isolated_tax_benefit_system) -> None: variables = {"salary": 4000} simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD) @@ -102,10 +101,10 @@ def test_variable_with_reference(make_simulation, isolated_tax_benefit_system): assert result > 0 class disposable_income(Variable): - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + def formula(self, period): + return self.empty_array() isolated_tax_benefit_system.update_variable(disposable_income) simulation = make_simulation(isolated_tax_benefit_system, variables, PERIOD) @@ -115,14 +114,13 @@ def formula(household, period): assert result == 0 -def test_variable_name_conflict(tax_benefit_system): - +def test_variable_name_conflict(tax_benefit_system) -> None: class disposable_income(Variable): reference = "disposable_income" - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + def formula(self, period): + return self.empty_array() with pytest.raises(VariableNameConflictError): tax_benefit_system.add_variable(disposable_income) diff --git a/tests/core/test_cycles.py b/tests/core/test_cycles.py index 1c4361ded2..acb08c6424 100644 --- a/tests/core/test_cycles.py +++ b/tests/core/test_cycles.py @@ -4,13 +4,14 @@ from openfisca_core import periods, tools from openfisca_core.errors import CycleError +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable @pytest.fixture def reference_period(): - return periods.period('2013-01') + return periods.period("2013-01") @pytest.fixture @@ -22,38 +23,38 @@ def simulation(tax_benefit_system): class variable1(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('variable2', period) + def formula(self, period): + return self("variable2", period) class variable2(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('variable1', period) + def formula(self, period): + return self("variable1", period) # 3 <--> 4 with a period offset class variable3(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('variable4', period.last_month) + def formula(self, period): + return self("variable4", period.last_month) class variable4(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('variable3', period) + def formula(self, period): + return self("variable3", period) # 5 -f-> 6 with a period offset @@ -61,30 +62,30 @@ def formula(person, period): class variable5(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - variable6 = person('variable6', period.last_month) + def formula(self, period): + variable6 = self("variable6", period.last_month) return 5 + variable6 class variable6(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - variable5 = person('variable5', period) + def formula(self, period): + variable5 = self("variable5", period) return 6 + variable5 class variable7(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - variable5 = person('variable5', period) + def formula(self, period): + variable5 = self("variable5", period) return 7 + variable5 @@ -92,17 +93,16 @@ def formula(person, period): class cotisation(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH - def formula(person, period): + def formula(self, period): if period.start.month == 12: - return 2 * person('cotisation', period.last_month) - else: - return person.empty_array() + 1 + return 2 * self("cotisation", period.last_month) + return self.empty_array() + 1 -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables( variable1, variable2, @@ -112,35 +112,38 @@ def add_variables_to_tax_benefit_system(tax_benefit_system): variable6, variable7, cotisation, - ) + ) -def test_pure_cycle(simulation, reference_period): +def test_pure_cycle(simulation, reference_period) -> None: with pytest.raises(CycleError): - simulation.calculate('variable1', period = reference_period) + simulation.calculate("variable1", period=reference_period) -def test_spirals_result_in_default_value(simulation, reference_period): - variable3 = simulation.calculate('variable3', period = reference_period) +def test_spirals_result_in_default_value(simulation, reference_period) -> None: + variable3 = simulation.calculate("variable3", period=reference_period) tools.assert_near(variable3, [0]) -def test_spiral_heuristic(simulation, reference_period): - variable5 = simulation.calculate('variable5', period = reference_period) - variable6 = simulation.calculate('variable6', period = reference_period) - variable6_last_month = simulation.calculate('variable6', reference_period.last_month) +def test_spiral_heuristic(simulation, reference_period) -> None: + variable5 = simulation.calculate("variable5", period=reference_period) + variable6 = simulation.calculate("variable6", period=reference_period) + variable6_last_month = simulation.calculate( + "variable6", + reference_period.last_month, + ) tools.assert_near(variable5, [11]) tools.assert_near(variable6, [11]) tools.assert_near(variable6_last_month, [11]) -def test_spiral_cache(simulation, reference_period): - simulation.calculate('variable7', period = reference_period) - cached_variable7 = simulation.get_holder('variable7').get_array(reference_period) +def test_spiral_cache(simulation, reference_period) -> None: + simulation.calculate("variable7", period=reference_period) + cached_variable7 = simulation.get_holder("variable7").get_array(reference_period) assert cached_variable7 is not None -def test_cotisation_1_level(simulation, reference_period): +def test_cotisation_1_level(simulation, reference_period) -> None: month = reference_period.last_month - cotisation = simulation.calculate('cotisation', period = month) + cotisation = simulation.calculate("cotisation", period=month) tools.assert_near(cotisation, [0]) diff --git a/tests/core/test_dump_restore.py b/tests/core/test_dump_restore.py index 5d377913c9..c84044165c 100644 --- a/tests/core/test_dump_restore.py +++ b/tests/core/test_dump_restore.py @@ -9,10 +9,13 @@ from openfisca_core.tools import simulation_dumper -def test_dump(tax_benefit_system): - directory = tempfile.mkdtemp(prefix = "openfisca_") - simulation = SimulationBuilder().build_from_entities(tax_benefit_system, situation_examples.couple) - calculated_value = simulation.calculate('disposable_income', '2018-01') +def test_dump(tax_benefit_system) -> None: + directory = tempfile.mkdtemp(prefix="openfisca_") + simulation = SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.couple, + ) + calculated_value = simulation.calculate("disposable_income", "2018-01") simulation_dumper.dump_simulation(simulation, directory) simulation_2 = simulation_dumper.restore_simulation(directory, tax_benefit_system) @@ -23,14 +26,23 @@ def test_dump(tax_benefit_system): testing.assert_array_equal(simulation.person.count, simulation_2.person.count) testing.assert_array_equal(simulation.household.ids, simulation_2.household.ids) testing.assert_array_equal(simulation.household.count, simulation_2.household.count) - testing.assert_array_equal(simulation.household.members_position, simulation_2.household.members_position) - testing.assert_array_equal(simulation.household.members_entity_id, simulation_2.household.members_entity_id) - testing.assert_array_equal(simulation.household.members_role, simulation_2.household.members_role) + testing.assert_array_equal( + simulation.household.members_position, + simulation_2.household.members_position, + ) + testing.assert_array_equal( + simulation.household.members_entity_id, + simulation_2.household.members_entity_id, + ) + testing.assert_array_equal( + simulation.household.members_role, + simulation_2.household.members_role, + ) # Check calculated values are in cache - disposable_income_holder = simulation_2.person.get_holder('disposable_income') - cached_value = disposable_income_holder.get_array('2018-01') + disposable_income_holder = simulation_2.person.get_holder("disposable_income") + cached_value = disposable_income_holder.get_array("2018-01") assert cached_value is not None testing.assert_array_equal(cached_value, calculated_value) diff --git a/tests/core/test_entities.py b/tests/core/test_entities.py index b15653b055..aba17dc4dc 100644 --- a/tests/core/test_entities.py +++ b/tests/core/test_entities.py @@ -7,17 +7,17 @@ from openfisca_core.tools import test_runner TEST_CASE = { - 'persons': {'ind0': {}, 'ind1': {}, 'ind2': {}, 'ind3': {}, 'ind4': {}, 'ind5': {}}, - 'households': { - 'h1': {'children': ['ind2', 'ind3'], 'parents': ['ind0', 'ind1']}, - 'h2': {'children': ['ind5'], 'parents': ['ind4']} - }, - } + "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}, "ind4": {}, "ind5": {}}, + "households": { + "h1": {"children": ["ind2", "ind3"], "parents": ["ind0", "ind1"]}, + "h2": {"children": ["ind5"], "parents": ["ind4"]}, + }, +} TEST_CASE_AGES = deepcopy(TEST_CASE) AGES = [40, 37, 7, 9, 54, 20] -for (individu, age) in zip(TEST_CASE_AGES['persons'].values(), AGES): - individu['age'] = age +for individu, age in zip(TEST_CASE_AGES["persons"].values(), AGES): + individu["age"] = age FIRST_PARENT = entities.Household.FIRST_PARENT SECOND_PARENT = entities.Household.SECOND_PARENT @@ -28,22 +28,25 @@ MONTH = "2016-01" -def new_simulation(tax_benefit_system, test_case, period = MONTH): +def new_simulation(tax_benefit_system, test_case, period=MONTH): simulation_builder = SimulationBuilder() simulation_builder.set_default_period(period) return simulation_builder.build_from_entities(tax_benefit_system, test_case) -def test_role_index_and_positions(tax_benefit_system): +def test_role_index_and_positions(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) tools.assert_near(simulation.household.members_entity_id, [0, 0, 0, 0, 1, 1]) - assert((simulation.household.members_role == [FIRST_PARENT, SECOND_PARENT, CHILD, CHILD, FIRST_PARENT, CHILD]).all()) + assert ( + simulation.household.members_role + == [FIRST_PARENT, SECOND_PARENT, CHILD, CHILD, FIRST_PARENT, CHILD] + ).all() tools.assert_near(simulation.household.members_position, [0, 1, 2, 3, 0, 1]) - assert(simulation.person.ids == ["ind0", "ind1", "ind2", "ind3", "ind4", "ind5"]) - assert(simulation.household.ids == ['h1', 'h2']) + assert simulation.person.ids == ["ind0", "ind1", "ind2", "ind3", "ind4", "ind5"] + assert simulation.household.ids == ["h1", "h2"] -def test_entity_structure_with_constructor(tax_benefit_system): +def test_entity_structure_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: {} @@ -64,16 +67,22 @@ def test_entity_structure_with_constructor(tax_benefit_system): - claudia """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), + ) household = simulation.household tools.assert_near(household.members_entity_id, [0, 0, 1, 0, 0]) - assert((household.members_role == [FIRST_PARENT, SECOND_PARENT, FIRST_PARENT, CHILD, CHILD]).all()) + assert ( + household.members_role + == [FIRST_PARENT, SECOND_PARENT, FIRST_PARENT, CHILD, CHILD] + ).all() tools.assert_near(household.members_position, [0, 1, 0, 2, 3]) -def test_entity_variables_with_constructor(tax_benefit_system): +def test_entity_variables_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: {} @@ -98,12 +107,15 @@ def test_entity_variables_with_constructor(tax_benefit_system): 2017-06: 600 """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), + ) household = simulation.household - tools.assert_near(household('rent', "2017-06"), [800, 600]) + tools.assert_near(household("rent", "2017-06"), [800, 600]) -def test_person_variable_with_constructor(tax_benefit_system): +def test_person_variable_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: @@ -131,13 +143,16 @@ def test_person_variable_with_constructor(tax_benefit_system): - claudia """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), + ) person = simulation.person - tools.assert_near(person('salary', "2017-11"), [1500, 0, 3000, 0, 0]) - tools.assert_near(person('salary', "2017-12"), [2000, 0, 4000, 0, 0]) + tools.assert_near(person("salary", "2017-11"), [1500, 0, 3000, 0, 0]) + tools.assert_near(person("salary", "2017-12"), [2000, 0, 4000, 0, 0]) -def test_set_input_with_constructor(tax_benefit_system): +def test_set_input_with_constructor(tax_benefit_system) -> None: simulation_yaml = """ persons: bill: @@ -170,136 +185,148 @@ def test_set_input_with_constructor(tax_benefit_system): - claudia """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(simulation_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(simulation_yaml), + ) person = simulation.person - tools.assert_near(person('salary', "2017-12"), [2000, 0, 4000, 0, 0]) - tools.assert_near(person('salary', "2017-10"), [2000, 3000, 1600, 0, 0]) + tools.assert_near(person("salary", "2017-12"), [2000, 0, 4000, 0, 0]) + tools.assert_near(person("salary", "2017-10"), [2000, 3000, 1600, 0, 0]) -def test_has_role(tax_benefit_system): +def test_has_role(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) individu = simulation.persons tools.assert_near(individu.has_role(CHILD), [False, False, True, True, False, True]) -def test_has_role_with_subrole(tax_benefit_system): +def test_has_role_with_subrole(tax_benefit_system) -> None: simulation = new_simulation(tax_benefit_system, TEST_CASE) individu = simulation.persons - tools.assert_near(individu.has_role(PARENT), [True, True, False, False, True, False]) - tools.assert_near(individu.has_role(FIRST_PARENT), [True, False, False, False, True, False]) - tools.assert_near(individu.has_role(SECOND_PARENT), [False, True, False, False, False, False]) - - -def test_project(tax_benefit_system): + tools.assert_near( + individu.has_role(PARENT), + [True, True, False, False, True, False], + ) + tools.assert_near( + individu.has_role(FIRST_PARENT), + [True, False, False, False, True, False], + ) + tools.assert_near( + individu.has_role(SECOND_PARENT), + [False, True, False, False, False, False], + ) + + +def test_project(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['households']['h1']['housing_tax'] = 20000 + test_case["households"]["h1"]["housing_tax"] = 20000 simulation = new_simulation(tax_benefit_system, test_case, YEAR) household = simulation.household - housing_tax = household('housing_tax', YEAR) + housing_tax = household("housing_tax", YEAR) projected_housing_tax = household.project(housing_tax) tools.assert_near(projected_housing_tax, [20000, 20000, 20000, 20000, 0, 0]) - housing_tax_projected_on_parents = household.project(housing_tax, role = PARENT) + housing_tax_projected_on_parents = household.project(housing_tax, role=PARENT) tools.assert_near(housing_tax_projected_on_parents, [20000, 20000, 0, 0, 0, 0]) -def test_implicit_projection(tax_benefit_system): +def test_implicit_projection(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['households']['h1']['housing_tax'] = 20000 + test_case["households"]["h1"]["housing_tax"] = 20000 simulation = new_simulation(tax_benefit_system, test_case, YEAR) individu = simulation.person - housing_tax = individu.household('housing_tax', YEAR) + housing_tax = individu.household("housing_tax", YEAR) tools.assert_near(housing_tax, [20000, 20000, 20000, 20000, 0, 0]) -def test_sum(tax_benefit_system): +def test_sum(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['persons']['ind0']['salary'] = 1000 - test_case['persons']['ind1']['salary'] = 1500 - test_case['persons']['ind4']['salary'] = 3000 - test_case['persons']['ind5']['salary'] = 500 + test_case["persons"]["ind0"]["salary"] = 1000 + test_case["persons"]["ind1"]["salary"] = 1500 + test_case["persons"]["ind4"]["salary"] = 3000 + test_case["persons"]["ind5"]["salary"] = 500 simulation = new_simulation(tax_benefit_system, test_case, MONTH) household = simulation.household - salary = household.members('salary', "2016-01") + salary = household.members("salary", "2016-01") total_salary_by_household = household.sum(salary) tools.assert_near(total_salary_by_household, [2500, 3500]) - total_salary_parents_by_household = household.sum(salary, role = PARENT) + total_salary_parents_by_household = household.sum(salary, role=PARENT) tools.assert_near(total_salary_parents_by_household, [2500, 3000]) -def test_any(tax_benefit_system): +def test_any(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - age = household.members('age', period = MONTH) - condition_age = (age <= 18) + age = household.members("age", period=MONTH) + condition_age = age <= 18 has_household_member_with_age_inf_18 = household.any(condition_age) tools.assert_near(has_household_member_with_age_inf_18, [True, False]) - condition_age_2 = (age > 18) - has_household_CHILD_with_age_sup_18 = household.any(condition_age_2, role = CHILD) + condition_age_2 = age > 18 + has_household_CHILD_with_age_sup_18 = household.any(condition_age_2, role=CHILD) tools.assert_near(has_household_CHILD_with_age_sup_18, [False, True]) -def test_all(tax_benefit_system): +def test_all(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - age = household.members('age', period = MONTH) + age = household.members("age", period=MONTH) - condition_age = (age >= 18) + condition_age = age >= 18 all_persons_age_sup_18 = household.all(condition_age) tools.assert_near(all_persons_age_sup_18, [False, True]) - all_parents_age_sup_18 = household.all(condition_age, role = PARENT) + all_parents_age_sup_18 = household.all(condition_age, role=PARENT) tools.assert_near(all_parents_age_sup_18, [True, True]) -def test_max(tax_benefit_system): +def test_max(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - age = household.members('age', period = MONTH) + age = household.members("age", period=MONTH) age_max = household.max(age) tools.assert_near(age_max, [40, 54]) - age_max_child = household.max(age, role = CHILD) + age_max_child = household.max(age, role=CHILD) tools.assert_near(age_max_child, [9, 20]) -def test_min(tax_benefit_system): +def test_min(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - age = household.members('age', period = MONTH) + age = household.members("age", period=MONTH) age_min = household.min(age) tools.assert_near(age_min, [7, 20]) - age_min_parents = household.min(age, role = PARENT) + age_min_parents = household.min(age, role=PARENT) tools.assert_near(age_min_parents, [37, 54]) -def test_value_nth_person(tax_benefit_system): +def test_value_nth_person(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - array = household.members('age', MONTH) + array = household.members("age", MONTH) result0 = household.value_nth_person(0, array, default=-1) tools.assert_near(result0, [40, 54]) @@ -314,141 +341,157 @@ def test_value_nth_person(tax_benefit_system): tools.assert_near(result3, [9, -1]) -def test_rank(tax_benefit_system): +def test_rank(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE_AGES) simulation = new_simulation(tax_benefit_system, test_case) person = simulation.person - age = person('age', MONTH) # [40, 37, 7, 9, 54, 20] + age = person("age", MONTH) # [40, 37, 7, 9, 54, 20] rank = person.get_rank(person.household, age) tools.assert_near(rank, [3, 2, 0, 1, 1, 0]) - rank_in_siblings = person.get_rank(person.household, - age, condition = person.has_role(entities.Household.CHILD)) + rank_in_siblings = person.get_rank( + person.household, + -age, + condition=person.has_role(entities.Household.CHILD), + ) tools.assert_near(rank_in_siblings, [-1, -1, 1, 0, -1, 0]) -def test_partner(tax_benefit_system): +def test_partner(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['persons']['ind0']['salary'] = 1000 - test_case['persons']['ind1']['salary'] = 1500 - test_case['persons']['ind4']['salary'] = 3000 - test_case['persons']['ind5']['salary'] = 500 + test_case["persons"]["ind0"]["salary"] = 1000 + test_case["persons"]["ind1"]["salary"] = 1500 + test_case["persons"]["ind4"]["salary"] = 3000 + test_case["persons"]["ind5"]["salary"] = 500 simulation = new_simulation(tax_benefit_system, test_case) persons = simulation.persons - salary = persons('salary', period = MONTH) + salary = persons("salary", period=MONTH) salary_second_parent = persons.value_from_partner(salary, persons.household, PARENT) tools.assert_near(salary_second_parent, [1500, 1000, 0, 0, 0, 0]) -def test_value_from_first_person(tax_benefit_system): +def test_value_from_first_person(tax_benefit_system) -> None: test_case = deepcopy(TEST_CASE) - test_case['persons']['ind0']['salary'] = 1000 - test_case['persons']['ind1']['salary'] = 1500 - test_case['persons']['ind4']['salary'] = 3000 - test_case['persons']['ind5']['salary'] = 500 + test_case["persons"]["ind0"]["salary"] = 1000 + test_case["persons"]["ind1"]["salary"] = 1500 + test_case["persons"]["ind4"]["salary"] = 3000 + test_case["persons"]["ind5"]["salary"] = 500 simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - salaries = household.members('salary', period = MONTH) + salaries = household.members("salary", period=MONTH) salary_first_person = household.value_from_first_person(salaries) tools.assert_near(salary_first_person, [1000, 3000]) -def test_projectors_methods(tax_benefit_system): - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, situation_examples.couple) +def test_projectors_methods(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + situation_examples.couple, + ) household = simulation.household person = simulation.person projected_vector = household.first_parent.has_role(entities.Household.FIRST_PARENT) - assert(len(projected_vector) == 1) # Must be of a household dimension + assert len(projected_vector) == 1 # Must be of a household dimension - salary_i = person.household.members('salary', '2017-01') - assert(len(person.household.sum(salary_i)) == 2) # Must be of a person dimension - assert(len(person.household.max(salary_i)) == 2) # Must be of a person dimension - assert(len(person.household.min(salary_i)) == 2) # Must be of a person dimension - assert(len(person.household.all(salary_i)) == 2) # Must be of a person dimension - assert(len(person.household.any(salary_i)) == 2) # Must be of a person dimension - assert(len(household.first_parent.get_rank(household, salary_i)) == 1) # Must be of a person dimension + salary_i = person.household.members("salary", "2017-01") + assert len(person.household.sum(salary_i)) == 2 # Must be of a person dimension + assert len(person.household.max(salary_i)) == 2 # Must be of a person dimension + assert len(person.household.min(salary_i)) == 2 # Must be of a person dimension + assert len(person.household.all(salary_i)) == 2 # Must be of a person dimension + assert len(person.household.any(salary_i)) == 2 # Must be of a person dimension + assert ( + len(household.first_parent.get_rank(household, salary_i)) == 1 + ) # Must be of a person dimension -def test_sum_following_bug_ipp_1(tax_benefit_system): +def test_sum_following_bug_ipp_1(tax_benefit_system) -> None: test_case = { - 'persons': {'ind0': {}, 'ind1': {}, 'ind2': {}, 'ind3': {}}, - 'households': { - 'h1': {'parents': ['ind0']}, - 'h2': {'parents': ['ind1'], 'children': ['ind2', 'ind3']} - }, - } - test_case['persons']['ind0']['salary'] = 2000 - test_case['persons']['ind1']['salary'] = 2000 - test_case['persons']['ind2']['salary'] = 1000 - test_case['persons']['ind3']['salary'] = 1000 + "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}}, + "households": { + "h1": {"parents": ["ind0"]}, + "h2": {"parents": ["ind1"], "children": ["ind2", "ind3"]}, + }, + } + test_case["persons"]["ind0"]["salary"] = 2000 + test_case["persons"]["ind1"]["salary"] = 2000 + test_case["persons"]["ind2"]["salary"] = 1000 + test_case["persons"]["ind3"]["salary"] = 1000 simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - eligible_i = household.members('salary', period = MONTH) < 1500 - nb_eligibles_by_household = household.sum(eligible_i, role = CHILD) + eligible_i = household.members("salary", period=MONTH) < 1500 + nb_eligibles_by_household = household.sum(eligible_i, role=CHILD) tools.assert_near(nb_eligibles_by_household, [0, 2]) -def test_sum_following_bug_ipp_2(tax_benefit_system): +def test_sum_following_bug_ipp_2(tax_benefit_system) -> None: test_case = { - 'persons': {'ind0': {}, 'ind1': {}, 'ind2': {}, 'ind3': {}}, - 'households': { - 'h1': {'parents': ['ind1'], 'children': ['ind2', 'ind3']}, - 'h2': {'parents': ['ind0']}, - }, - } - test_case['persons']['ind0']['salary'] = 2000 - test_case['persons']['ind1']['salary'] = 2000 - test_case['persons']['ind2']['salary'] = 1000 - test_case['persons']['ind3']['salary'] = 1000 + "persons": {"ind0": {}, "ind1": {}, "ind2": {}, "ind3": {}}, + "households": { + "h1": {"parents": ["ind1"], "children": ["ind2", "ind3"]}, + "h2": {"parents": ["ind0"]}, + }, + } + test_case["persons"]["ind0"]["salary"] = 2000 + test_case["persons"]["ind1"]["salary"] = 2000 + test_case["persons"]["ind2"]["salary"] = 1000 + test_case["persons"]["ind3"]["salary"] = 1000 simulation = new_simulation(tax_benefit_system, test_case) household = simulation.household - eligible_i = household.members('salary', period = MONTH) < 1500 - nb_eligibles_by_household = household.sum(eligible_i, role = CHILD) + eligible_i = household.members("salary", period=MONTH) < 1500 + nb_eligibles_by_household = household.sum(eligible_i, role=CHILD) tools.assert_near(nb_eligibles_by_household, [2, 0]) -def test_get_memory_usage(tax_benefit_system): +def test_get_memory_usage(tax_benefit_system) -> None: test_case = deepcopy(situation_examples.single) test_case["persons"]["Alicia"]["salary"] = {"2017-01": 0} simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_case) - simulation.calculate('disposable_income', '2017-01') - memory_usage = simulation.person.get_memory_usage(variables = ['salary']) - assert(memory_usage['total_nb_bytes'] > 0) - assert(len(memory_usage['by_variable']) == 1) + simulation.calculate("disposable_income", "2017-01") + memory_usage = simulation.person.get_memory_usage(variables=["salary"]) + assert memory_usage["total_nb_bytes"] > 0 + assert len(memory_usage["by_variable"]) == 1 -def test_unordered_persons(tax_benefit_system): +def test_unordered_persons(tax_benefit_system) -> None: test_case = { - 'persons': {'ind4': {}, 'ind3': {}, 'ind1': {}, 'ind2': {}, 'ind5': {}, 'ind0': {}}, - 'households': { - 'h1': {'children': ['ind2', 'ind3'], 'parents': ['ind0', 'ind1']}, - 'h2': {'children': ['ind5'], 'parents': ['ind4']} - }, - } + "persons": { + "ind4": {}, + "ind3": {}, + "ind1": {}, + "ind2": {}, + "ind5": {}, + "ind0": {}, + }, + "households": { + "h1": {"children": ["ind2", "ind3"], "parents": ["ind0", "ind1"]}, + "h2": {"children": ["ind5"], "parents": ["ind4"]}, + }, + } # 1st family - test_case['persons']['ind0']['salary'] = 1000 - test_case['persons']['ind1']['salary'] = 1500 - test_case['persons']['ind2']['salary'] = 20 - test_case['households']['h1']['accommodation_size'] = 160 + test_case["persons"]["ind0"]["salary"] = 1000 + test_case["persons"]["ind1"]["salary"] = 1500 + test_case["persons"]["ind2"]["salary"] = 20 + test_case["households"]["h1"]["accommodation_size"] = 160 # 2nd family - test_case['persons']['ind4']['salary'] = 3000 - test_case['persons']['ind5']['salary'] = 500 - test_case['households']['h2']['accommodation_size'] = 60 + test_case["persons"]["ind4"]["salary"] = 3000 + test_case["persons"]["ind5"]["salary"] = 500 + test_case["households"]["h2"]["accommodation_size"] = 60 # household.members_entity_id == [1, 0, 0, 0, 1, 0] @@ -456,8 +499,8 @@ def test_unordered_persons(tax_benefit_system): household = simulation.household person = simulation.person - salary = household.members('salary', "2016-01") # [ 3000, 0, 1500, 20, 500, 1000 ] - accommodation_size = household('accommodation_size', "2016-01") # [ 160, 60 ] + salary = household.members("salary", "2016-01") # [ 3000, 0, 1500, 20, 500, 1000 ] + accommodation_size = household("accommodation_size", "2016-01") # [ 160, 60 ] # Aggregation/Projection persons -> entity @@ -466,30 +509,42 @@ def test_unordered_persons(tax_benefit_system): tools.assert_near(household.min(salary), [0, 500]) tools.assert_near(household.all(salary > 0), [False, True]) tools.assert_near(household.any(salary > 2000), [False, True]) - tools.assert_near(household.first_person('salary', "2016-01"), [0, 3000]) - tools.assert_near(household.first_parent('salary', "2016-01"), [1000, 3000]) - tools.assert_near(household.second_parent('salary', "2016-01"), [1500, 0]) - tools.assert_near(person.value_from_partner(salary, person.household, PARENT), [0, 0, 1000, 0, 0, 1500]) - - tools.assert_near(household.sum(salary, role = PARENT), [2500, 3000]) - tools.assert_near(household.sum(salary, role = CHILD), [20, 500]) - tools.assert_near(household.max(salary, role = PARENT), [1500, 3000]) - tools.assert_near(household.max(salary, role = CHILD), [20, 500]) - tools.assert_near(household.min(salary, role = PARENT), [1000, 3000]) - tools.assert_near(household.min(salary, role = CHILD), [0, 500]) - tools.assert_near(household.all(salary > 0, role = PARENT), [True, True]) - tools.assert_near(household.all(salary > 0, role = CHILD), [False, True]) - tools.assert_near(household.any(salary < 1500, role = PARENT), [True, False]) - tools.assert_near(household.any(salary > 200, role = CHILD), [False, True]) + tools.assert_near(household.first_person("salary", "2016-01"), [0, 3000]) + tools.assert_near(household.first_parent("salary", "2016-01"), [1000, 3000]) + tools.assert_near(household.second_parent("salary", "2016-01"), [1500, 0]) + tools.assert_near( + person.value_from_partner(salary, person.household, PARENT), + [0, 0, 1000, 0, 0, 1500], + ) + + tools.assert_near(household.sum(salary, role=PARENT), [2500, 3000]) + tools.assert_near(household.sum(salary, role=CHILD), [20, 500]) + tools.assert_near(household.max(salary, role=PARENT), [1500, 3000]) + tools.assert_near(household.max(salary, role=CHILD), [20, 500]) + tools.assert_near(household.min(salary, role=PARENT), [1000, 3000]) + tools.assert_near(household.min(salary, role=CHILD), [0, 500]) + tools.assert_near(household.all(salary > 0, role=PARENT), [True, True]) + tools.assert_near(household.all(salary > 0, role=CHILD), [False, True]) + tools.assert_near(household.any(salary < 1500, role=PARENT), [True, False]) + tools.assert_near(household.any(salary > 200, role=CHILD), [False, True]) # nb_persons tools.assert_near(household.nb_persons(), [4, 2]) - tools.assert_near(household.nb_persons(role = PARENT), [2, 1]) - tools.assert_near(household.nb_persons(role = CHILD), [2, 1]) + tools.assert_near(household.nb_persons(role=PARENT), [2, 1]) + tools.assert_near(household.nb_persons(role=CHILD), [2, 1]) # Projection entity -> persons - tools.assert_near(household.project(accommodation_size), [60, 160, 160, 160, 60, 160]) - tools.assert_near(household.project(accommodation_size, role = PARENT), [60, 0, 160, 0, 0, 160]) - tools.assert_near(household.project(accommodation_size, role = CHILD), [0, 160, 0, 160, 60, 0]) + tools.assert_near( + household.project(accommodation_size), + [60, 160, 160, 160, 60, 160], + ) + tools.assert_near( + household.project(accommodation_size, role=PARENT), + [60, 0, 160, 0, 0, 160], + ) + tools.assert_near( + household.project(accommodation_size, role=CHILD), + [0, 160, 0, 160, 60, 0], + ) diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py index 5c3da81d66..4854815ac3 100644 --- a/tests/core/test_extensions.py +++ b/tests/core/test_extensions.py @@ -1,24 +1,26 @@ import pytest -def test_load_extension(tax_benefit_system): +def test_load_extension(tax_benefit_system) -> None: tbs = tax_benefit_system.clone() - assert tbs.get_variable('local_town_child_allowance') is None + assert tbs.get_variable("local_town_child_allowance") is None - tbs.load_extension('openfisca_extension_template') + tbs.load_extension("openfisca_extension_template") - assert tbs.get_variable('local_town_child_allowance') is not None - assert tax_benefit_system.get_variable('local_town_child_allowance') is None + assert tbs.get_variable("local_town_child_allowance") is not None + assert tax_benefit_system.get_variable("local_town_child_allowance") is None -def test_access_to_parameters(tax_benefit_system): +def test_access_to_parameters(tax_benefit_system) -> None: tbs = tax_benefit_system.clone() - tbs.load_extension('openfisca_extension_template') + tbs.load_extension("openfisca_extension_template") - assert tbs.parameters('2016-01').local_town.child_allowance.amount == 100.0 - assert tbs.parameters.local_town.child_allowance.amount('2016-01') == 100.0 + assert tbs.parameters("2016-01").local_town.child_allowance.amount == 100.0 + assert tbs.parameters.local_town.child_allowance.amount("2016-01") == 100.0 -def test_failure_to_load_extension_when_directory_doesnt_exist(tax_benefit_system): +def test_failure_to_load_extension_when_directory_doesnt_exist( + tax_benefit_system, +) -> None: with pytest.raises(ValueError): - tax_benefit_system.load_extension('/this/is/not/a/real/path') + tax_benefit_system.load_extension("/this/is/not/a/real/path") diff --git a/tests/core/test_formulas.py b/tests/core/test_formulas.py index 876ca239d1..32e6fd35e7 100644 --- a/tests/core/test_formulas.py +++ b/tests/core/test_formulas.py @@ -1,96 +1,187 @@ import numpy +from pytest import approx, fixture from openfisca_country_template import entities -from openfisca_core import commons, periods +from openfisca_core import commons +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder from openfisca_core.variables import Variable -from pytest import fixture, approx - class choice(Variable): value_type = int entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH class uses_multiplication(Variable): value_type = int entity = entities.Person - label = 'Variable with formula that uses multiplication' - definition_period = periods.MONTH + label = "Variable with formula that uses multiplication" + definition_period = DateUnit.MONTH - def formula(person, period): - choice = person('choice', period) - result = (choice == 1) * 80 + (choice == 2) * 90 - return result + def formula(self, period): + choice = self("choice", period) + return (choice == 1) * 80 + (choice == 2) * 90 class returns_scalar(Variable): value_type = int entity = entities.Person - label = 'Variable with formula that returns a scalar value' - definition_period = periods.MONTH + label = "Variable with formula that returns a scalar value" + definition_period = DateUnit.MONTH - def formula(person, period): + def formula(self, period) -> int: return 666 class uses_switch(Variable): value_type = int entity = entities.Person - label = 'Variable with formula that uses switch' - definition_period = periods.MONTH + label = "Variable with formula that uses switch" + definition_period = DateUnit.MONTH - def formula(person, period): - choice = person('choice', period) - result = commons.switch( + def formula(self, period): + choice = self("choice", period) + return commons.switch( choice, { 1: 80, 2: 90, - }, - ) - return result + }, + ) -@fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): - tax_benefit_system.add_variables(choice, uses_multiplication, uses_switch, returns_scalar) +@fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: + tax_benefit_system.add_variables( + choice, + uses_multiplication, + uses_switch, + returns_scalar, + ) @fixture -def month(): - return '2013-01' +def month() -> str: + return "2013-01" @fixture def simulation(tax_benefit_system, month): simulation_builder = SimulationBuilder() simulation_builder.default_period = month - simulation = simulation_builder.build_from_variables(tax_benefit_system, {'choice': numpy.random.randint(2, size = 1000) + 1}) + simulation = simulation_builder.build_from_variables( + tax_benefit_system, + {"choice": numpy.random.randint(2, size=1000) + 1}, + ) simulation.debug = True return simulation -def test_switch(simulation, month): - uses_switch = simulation.calculate('uses_switch', period = month) +def test_switch(simulation, month) -> None: + uses_switch = simulation.calculate("uses_switch", period=month) assert isinstance(uses_switch, numpy.ndarray) -def test_multiplication(simulation, month): - uses_multiplication = simulation.calculate('uses_multiplication', period = month) +def test_multiplication(simulation, month) -> None: + uses_multiplication = simulation.calculate("uses_multiplication", period=month) assert isinstance(uses_multiplication, numpy.ndarray) -def test_broadcast_scalar(simulation, month): - array_value = simulation.calculate('returns_scalar', period = month) +def test_broadcast_scalar(simulation, month) -> None: + array_value = simulation.calculate("returns_scalar", period=month) assert isinstance(array_value, numpy.ndarray) assert array_value == approx(numpy.repeat(666, 1000)) -def test_compare_multiplication_and_switch(simulation, month): - uses_multiplication = simulation.calculate('uses_multiplication', period = month) - uses_switch = simulation.calculate('uses_switch', period = month) +def test_compare_multiplication_and_switch(simulation, month) -> None: + uses_multiplication = simulation.calculate("uses_multiplication", period=month) + uses_switch = simulation.calculate("uses_switch", period=month) assert numpy.all(uses_switch == uses_multiplication) + + +def test_group_encapsulation() -> None: + """Projects a calculation to all members of an entity. + + When a household contains more than one family + Variables can be defined for the the household + And calculations are projected to all the member families. + + """ + from openfisca_core.entities import build_entity + from openfisca_core.periods import DateUnit + from openfisca_core.taxbenefitsystems import TaxBenefitSystem + + person_entity = build_entity( + key="person", + plural="people", + label="A person", + is_person=True, + ) + family_entity = build_entity( + key="family", + plural="families", + label="A family (all members in the same household)", + containing_entities=["household"], + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + household_entity = build_entity( + key="household", + plural="households", + label="A household, containing one or more families", + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + + entities = [person_entity, family_entity, household_entity] + + system = TaxBenefitSystem(entities) + + class household_level_variable(Variable): + value_type = int + entity = household_entity + definition_period = DateUnit.ETERNITY + + class projected_family_level_variable(Variable): + value_type = int + entity = family_entity + definition_period = DateUnit.ETERNITY + + def formula(self, period): + return self.household("household_level_variable", period) + + system.add_variables(household_level_variable, projected_family_level_variable) + + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": {"person1": {}, "person2": {}, "person3": {}}, + "families": { + "family1": {"members": ["person1", "person2"]}, + "family2": {"members": ["person3"]}, + }, + "households": { + "household1": { + "members": ["person1", "person2", "person3"], + "household_level_variable": {"eternity": 5}, + }, + }, + }, + ) + + assert ( + simulation.calculate("projected_family_level_variable", "2021-01-01") == 5 + ).all() diff --git a/tests/core/test_holders.py b/tests/core/test_holders.py index cd26231037..c72d053ad6 100644 --- a/tests/core/test_holders.py +++ b/tests/core/test_holders.py @@ -1,185 +1,218 @@ -import pytest - import numpy +import pytest from openfisca_country_template import situation_examples from openfisca_country_template.variables import housing from openfisca_core import holders, periods, tools from openfisca_core.errors import PeriodMismatchError +from openfisca_core.holders import Holder from openfisca_core.memory_config import MemoryConfig +from openfisca_core.periods import DateUnit from openfisca_core.simulations import SimulationBuilder -from openfisca_core.holders import Holder @pytest.fixture def single(tax_benefit_system): - return \ - SimulationBuilder() \ - .build_from_entities(tax_benefit_system, situation_examples.single) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.single, + ) @pytest.fixture def couple(tax_benefit_system): - return \ - SimulationBuilder(). \ - build_from_entities(tax_benefit_system, situation_examples.couple) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + situation_examples.couple, + ) -period = periods.period('2017-12') +period = periods.period("2017-12") -def test_set_input_enum_string(couple): +def test_set_input_enum_string(couple) -> None: simulation = couple - status_occupancy = numpy.asarray(['free_lodger']) - simulation.household.get_holder('housing_occupancy_status').set_input(period, status_occupancy) - result = simulation.calculate('housing_occupancy_status', period) + status_occupancy = numpy.asarray(["free_lodger"]) + simulation.household.get_holder("housing_occupancy_status").set_input( + period, + status_occupancy, + ) + result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_set_input_enum_int(couple): +def test_set_input_enum_int(couple) -> None: simulation = couple - status_occupancy = numpy.asarray([2], dtype = numpy.int16) - simulation.household.get_holder('housing_occupancy_status').set_input(period, status_occupancy) - result = simulation.calculate('housing_occupancy_status', period) + status_occupancy = numpy.asarray([2], dtype=numpy.int16) + simulation.household.get_holder("housing_occupancy_status").set_input( + period, + status_occupancy, + ) + result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_set_input_enum_item(couple): +def test_set_input_enum_item(couple) -> None: simulation = couple status_occupancy = numpy.asarray([housing.HousingOccupancyStatus.free_lodger]) - simulation.household.get_holder('housing_occupancy_status').set_input(period, status_occupancy) - result = simulation.calculate('housing_occupancy_status', period) + simulation.household.get_holder("housing_occupancy_status").set_input( + period, + status_occupancy, + ) + result = simulation.calculate("housing_occupancy_status", period) assert result == housing.HousingOccupancyStatus.free_lodger -def test_yearly_input_month_variable(couple): +def test_yearly_input_month_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: - couple.set_input('rent', 2019, 3000) - assert 'Unable to set a value for variable "rent" for year-long period' in error.value.message + couple.set_input("rent", 2019, 3000) + assert ( + 'Unable to set a value for variable "rent" for year-long period' + in error.value.message + ) -def test_3_months_input_month_variable(couple): +def test_3_months_input_month_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: - couple.set_input('rent', 'month:2019-01:3', 3000) - assert 'Unable to set a value for variable "rent" for 3-months-long period' in error.value.message + couple.set_input("rent", "month:2019-01:3", 3000) + assert ( + 'Unable to set a value for variable "rent" for 3-months-long period' + in error.value.message + ) -def test_month_input_year_variable(couple): +def test_month_input_year_variable(couple) -> None: with pytest.raises(PeriodMismatchError) as error: - couple.set_input('housing_tax', '2019-01', 3000) - assert 'Unable to set a value for variable "housing_tax" for month-long period' in error.value.message + couple.set_input("housing_tax", "2019-01", 3000) + assert ( + 'Unable to set a value for variable "housing_tax" for month-long period' + in error.value.message + ) -def test_enum_dtype(couple): +def test_enum_dtype(couple) -> None: simulation = couple - status_occupancy = numpy.asarray([2], dtype = numpy.int16) - simulation.household.get_holder('housing_occupancy_status').set_input(period, status_occupancy) - result = simulation.calculate('housing_occupancy_status', period) + status_occupancy = numpy.asarray([2], dtype=numpy.int16) + simulation.household.get_holder("housing_occupancy_status").set_input( + period, + status_occupancy, + ) + result = simulation.calculate("housing_occupancy_status", period) assert result.dtype.kind is not None -def test_permanent_variable_empty(single): +def test_permanent_variable_empty(single) -> None: simulation = single - holder = simulation.person.get_holder('birth') + holder = simulation.person.get_holder("birth") assert holder.get_array(None) is None -def test_permanent_variable_filled(single): +def test_permanent_variable_filled(single) -> None: simulation = single - holder = simulation.person.get_holder('birth') - value = numpy.asarray(['1980-01-01'], dtype = holder.variable.dtype) - holder.set_input(periods.period(periods.ETERNITY), value) + holder = simulation.person.get_holder("birth") + value = numpy.asarray(["1980-01-01"], dtype=holder.variable.dtype) + holder.set_input(periods.period(DateUnit.ETERNITY), value) assert holder.get_array(None) == value - assert holder.get_array(periods.ETERNITY) == value - assert holder.get_array('2016-01') == value + assert holder.get_array(DateUnit.ETERNITY) == value + assert holder.get_array("2016-01") == value -def test_delete_arrays(single): +def test_delete_arrays(single) -> None: simulation = single - salary_holder = simulation.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) salary_holder.set_input(periods.period(2018), numpy.asarray([60000])) - assert simulation.person('salary', '2017-01') == 2500 - assert simulation.person('salary', '2018-01') == 5000 - salary_holder.delete_arrays(period = 2018) + assert simulation.person("salary", "2017-01") == 2500 + assert simulation.person("salary", "2018-01") == 5000 + salary_holder.delete_arrays(period=2018) + + salary_array = simulation.get_array("salary", "2017-01") + assert salary_array is not None + salary_array = simulation.get_array("salary", "2018-01") + assert salary_array is None + salary_holder.set_input(periods.period(2018), numpy.asarray([15000])) - assert simulation.person('salary', '2017-01') == 2500 - assert simulation.person('salary', '2018-01') == 1250 + assert simulation.person("salary", "2017-01") == 2500 + assert simulation.person("salary", "2018-01") == 1250 -def test_get_memory_usage(single): +def test_get_memory_usage(single) -> None: simulation = single - salary_holder = simulation.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") memory_usage = salary_holder.get_memory_usage() - assert memory_usage['total_nb_bytes'] == 0 + assert memory_usage["total_nb_bytes"] == 0 salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) memory_usage = salary_holder.get_memory_usage() - assert memory_usage['nb_cells_by_array'] == 1 - assert memory_usage['cell_size'] == 4 # float 32 - assert memory_usage['nb_cells_by_array'] == 1 # one person - assert memory_usage['nb_arrays'] == 12 # 12 months - assert memory_usage['total_nb_bytes'] == 4 * 12 * 1 + assert memory_usage["nb_cells_by_array"] == 1 + assert memory_usage["cell_size"] == 4 # float 32 + assert memory_usage["nb_cells_by_array"] == 1 # one person + assert memory_usage["nb_arrays"] == 12 # 12 months + assert memory_usage["total_nb_bytes"] == 4 * 12 * 1 -def test_get_memory_usage_with_trace(single): +def test_get_memory_usage_with_trace(single) -> None: simulation = single simulation.trace = True - salary_holder = simulation.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) - simulation.calculate('salary', '2017-01') - simulation.calculate('salary', '2017-01') - simulation.calculate('salary', '2017-02') - simulation.calculate_add('salary', '2017') # 12 calculations + simulation.calculate("salary", "2017-01") + simulation.calculate("salary", "2017-01") + simulation.calculate("salary", "2017-02") + simulation.calculate_add("salary", "2017") # 12 calculations memory_usage = salary_holder.get_memory_usage() - assert memory_usage['nb_requests'] == 15 - assert memory_usage['nb_requests_by_array'] == 1.25 # 15 calculations / 12 arrays + assert memory_usage["nb_requests"] == 15 + assert memory_usage["nb_requests_by_array"] == 1.25 # 15 calculations / 12 arrays -def test_set_input_dispatch_by_period(single): +def test_set_input_dispatch_by_period(single) -> None: simulation = single - variable = simulation.tax_benefit_system.get_variable('housing_occupancy_status') + variable = simulation.tax_benefit_system.get_variable("housing_occupancy_status") entity = simulation.household holder = Holder(variable, entity) - holders.set_input_dispatch_by_period(holder, periods.period(2019), 'owner') - assert holder.get_array('2019-01') == holder.get_array('2019-12') # Check the feature - assert holder.get_array('2019-01') is holder.get_array('2019-12') # Check that the vectors are the same in memory, to avoid duplication + holders.set_input_dispatch_by_period(holder, periods.period(2019), "owner") + assert holder.get_array("2019-01") == holder.get_array( + "2019-12", + ) # Check the feature + assert holder.get_array("2019-01") is holder.get_array( + "2019-12", + ) # Check that the vectors are the same in memory, to avoid duplication -force_storage_on_disk = MemoryConfig(max_memory_occupation = 0) +force_storage_on_disk = MemoryConfig(max_memory_occupation=0) -def test_delete_arrays_on_disk(single): +def test_delete_arrays_on_disk(single) -> None: simulation = single simulation.memory_config = force_storage_on_disk - salary_holder = simulation.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") salary_holder.set_input(periods.period(2017), numpy.asarray([30000])) salary_holder.set_input(periods.period(2018), numpy.asarray([60000])) - assert simulation.person('salary', '2017-01') == 2500 - assert simulation.person('salary', '2018-01') == 5000 - salary_holder.delete_arrays(period = 2018) + assert simulation.person("salary", "2017-01") == 2500 + assert simulation.person("salary", "2018-01") == 5000 + salary_holder.delete_arrays(period=2018) salary_holder.set_input(periods.period(2018), numpy.asarray([15000])) - assert simulation.person('salary', '2017-01') == 2500 - assert simulation.person('salary', '2018-01') == 1250 + assert simulation.person("salary", "2017-01") == 2500 + assert simulation.person("salary", "2018-01") == 1250 -def test_cache_disk(couple): +def test_cache_disk(couple) -> None: simulation = couple simulation.memory_config = force_storage_on_disk - month = periods.period('2017-01') - holder = simulation.person.get_holder('disposable_income') + month = periods.period("2017-01") + holder = simulation.person.get_holder("disposable_income") data = numpy.asarray([2000, 3000]) holder.put_in_cache(data, month) stored_data = holder.get_array(month) tools.assert_near(data, stored_data) -def test_known_periods(couple): +def test_known_periods(couple) -> None: simulation = couple simulation.memory_config = force_storage_on_disk - month = periods.period('2017-01') - month_2 = periods.period('2017-02') - holder = simulation.person.get_holder('disposable_income') + month = periods.period("2017-01") + month_2 = periods.period("2017-02") + holder = simulation.person.get_holder("disposable_income") data = numpy.asarray([2000, 3000]) holder.put_in_cache(data, month) holder._memory_storage.put(data, month_2) @@ -187,28 +220,34 @@ def test_known_periods(couple): assert sorted(holder.get_known_periods()), [month == month_2] -def test_cache_enum_on_disk(single): +def test_cache_enum_on_disk(single) -> None: simulation = single simulation.memory_config = force_storage_on_disk - month = periods.period('2017-01') - simulation.calculate('housing_occupancy_status', month) # First calculation - housing_occupancy_status = simulation.calculate('housing_occupancy_status', month) # Read from cache + month = periods.period("2017-01") + simulation.calculate("housing_occupancy_status", month) # First calculation + housing_occupancy_status = simulation.calculate( + "housing_occupancy_status", + month, + ) # Read from cache assert housing_occupancy_status == housing.HousingOccupancyStatus.tenant -def test_set_not_cached_variable(single): - dont_cache_variable = MemoryConfig(max_memory_occupation = 1, variables_to_drop = ['salary']) +def test_set_not_cached_variable(single) -> None: + dont_cache_variable = MemoryConfig( + max_memory_occupation=1, + variables_to_drop=["salary"], + ) simulation = single simulation.memory_config = dont_cache_variable - holder = simulation.person.get_holder('salary') + holder = simulation.person.get_holder("salary") array = numpy.asarray([2000]) - holder.set_input('2015-01', array) - assert simulation.calculate('salary', '2015-01') == array + holder.set_input("2015-01", array) + assert simulation.calculate("salary", "2015-01") == array -def test_set_input_float_to_int(single): +def test_set_input_float_to_int(single) -> None: simulation = single age = numpy.asarray([50.6]) - simulation.person.get_holder('age').set_input(period, age) - result = simulation.calculate('age', period) + simulation.person.get_holder("age").set_input(period, age) + result = simulation.calculate("age", period) assert result == numpy.asarray([50]) diff --git a/tests/core/test_opt_out_cache.py b/tests/core/test_opt_out_cache.py index b4eab3e5a5..2f61da2898 100644 --- a/tests/core/test_opt_out_cache.py +++ b/tests/core/test_opt_out_cache.py @@ -3,10 +3,9 @@ from openfisca_country_template.entities import Person from openfisca_core import periods -from openfisca_core.periods import MONTH +from openfisca_core.periods import DateUnit from openfisca_core.variables import Variable - PERIOD = periods.period("2016-01") @@ -14,57 +13,57 @@ class input(Variable): value_type = int entity = Person label = "Input variable" - definition_period = MONTH + definition_period = DateUnit.MONTH class intermediate(Variable): value_type = int entity = Person label = "Intermediate result that don't need to be cached" - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula(person, period): - return person('input', period) + def formula(self, period): + return self("input", period) class output(Variable): value_type = int entity = Person - label = 'Output variable' - definition_period = MONTH + label = "Output variable" + definition_period = DateUnit.MONTH - def formula(person, period): - return person('intermediate', period) + def formula(self, period): + return self("intermediate", period) -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables(input, intermediate, output) -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_cache_blakclist(tax_benefit_system): - tax_benefit_system.cache_blacklist = set(['intermediate']) +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_cache_blakclist(tax_benefit_system) -> None: + tax_benefit_system.cache_blacklist = {"intermediate"} -@pytest.mark.parametrize("simulation", [({'input': 1}, PERIOD)], indirect = True) -def test_without_cache_opt_out(simulation): - simulation.calculate('output', period = PERIOD) - intermediate_cache = simulation.persons.get_holder('intermediate') - assert(intermediate_cache.get_array(PERIOD) is not None) +@pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) +def test_without_cache_opt_out(simulation) -> None: + simulation.calculate("output", period=PERIOD) + intermediate_cache = simulation.persons.get_holder("intermediate") + assert intermediate_cache.get_array(PERIOD) is not None -@pytest.mark.parametrize("simulation", [({'input': 1}, PERIOD)], indirect = True) -def test_with_cache_opt_out(simulation): +@pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) +def test_with_cache_opt_out(simulation) -> None: simulation.debug = True simulation.opt_out_cache = True - simulation.calculate('output', period = PERIOD) - intermediate_cache = simulation.persons.get_holder('intermediate') - assert(intermediate_cache.get_array(PERIOD) is None) + simulation.calculate("output", period=PERIOD) + intermediate_cache = simulation.persons.get_holder("intermediate") + assert intermediate_cache.get_array(PERIOD) is None -@pytest.mark.parametrize("simulation", [({'input': 1}, PERIOD)], indirect = True) -def test_with_no_blacklist(simulation): - simulation.calculate('output', period = PERIOD) - intermediate_cache = simulation.persons.get_holder('intermediate') - assert(intermediate_cache.get_array(PERIOD) is not None) +@pytest.mark.parametrize("simulation", [({"input": 1}, PERIOD)], indirect=True) +def test_with_no_blacklist(simulation) -> None: + simulation.calculate("output", period=PERIOD) + intermediate_cache = simulation.persons.get_holder("intermediate") + assert intermediate_cache.get_array(PERIOD) is not None diff --git a/tests/core/test_parameters.py b/tests/core/test_parameters.py index 40d8bb3fc9..7fe63a8180 100644 --- a/tests/core/test_parameters.py +++ b/tests/core/test_parameters.py @@ -2,106 +2,135 @@ import pytest -from openfisca_core.parameters import ParameterNotFound, ParameterNode, ParameterNodeAtInstant, load_parameter_file +from openfisca_core.parameters import ( + ParameterNode, + ParameterNodeAtInstant, + ParameterNotFound, + load_parameter_file, +) -def test_get_at_instant(tax_benefit_system): +def test_get_at_instant(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters assert isinstance(parameters, ParameterNode), parameters - parameters_at_instant = parameters('2016-01-01') - assert isinstance(parameters_at_instant, ParameterNodeAtInstant), parameters_at_instant + parameters_at_instant = parameters("2016-01-01") + assert isinstance( + parameters_at_instant, + ParameterNodeAtInstant, + ), parameters_at_instant assert parameters_at_instant.taxes.income_tax_rate == 0.15 assert parameters_at_instant.benefits.basic_income == 600 -def test_param_values(tax_benefit_system): +def test_param_values(tax_benefit_system) -> None: dated_values = { - '2015-01-01': 0.15, - '2014-01-01': 0.14, - '2013-01-01': 0.13, - '2012-01-01': 0.16, - } + "2015-01-01": 0.15, + "2014-01-01": 0.14, + "2013-01-01": 0.13, + "2012-01-01": 0.16, + } for date, value in dated_values.items(): - assert tax_benefit_system.get_parameters_at_instant(date).taxes.income_tax_rate == value + assert ( + tax_benefit_system.get_parameters_at_instant(date).taxes.income_tax_rate + == value + ) -def test_param_before_it_is_defined(tax_benefit_system): +def test_param_before_it_is_defined(tax_benefit_system) -> None: with pytest.raises(ParameterNotFound): - tax_benefit_system.get_parameters_at_instant('1997-12-31').taxes.income_tax_rate + tax_benefit_system.get_parameters_at_instant("1997-12-31").taxes.income_tax_rate # The placeholder should have no effect on the parameter computation -def test_param_with_placeholder(tax_benefit_system): - assert tax_benefit_system.get_parameters_at_instant('2018-01-01').taxes.income_tax_rate == 0.15 +def test_param_with_placeholder(tax_benefit_system) -> None: + assert ( + tax_benefit_system.get_parameters_at_instant("2018-01-01").taxes.income_tax_rate + == 0.15 + ) -def test_stopped_parameter_before_end_value(tax_benefit_system): - assert tax_benefit_system.get_parameters_at_instant('2011-12-31').benefits.housing_allowance == 0.25 +def test_stopped_parameter_before_end_value(tax_benefit_system) -> None: + assert ( + tax_benefit_system.get_parameters_at_instant( + "2011-12-31", + ).benefits.housing_allowance + == 0.25 + ) -def test_stopped_parameter_after_end_value(tax_benefit_system): +def test_stopped_parameter_after_end_value(tax_benefit_system) -> None: with pytest.raises(ParameterNotFound): - tax_benefit_system.get_parameters_at_instant('2016-12-01').benefits.housing_allowance + tax_benefit_system.get_parameters_at_instant( + "2016-12-01", + ).benefits.housing_allowance -def test_parameter_for_period(tax_benefit_system): +def test_parameter_for_period(tax_benefit_system) -> None: income_tax_rate = tax_benefit_system.parameters.taxes.income_tax_rate assert income_tax_rate("2015") == income_tax_rate("2015-01-01") -def test_wrong_value(tax_benefit_system): +def test_wrong_value(tax_benefit_system) -> None: income_tax_rate = tax_benefit_system.parameters.taxes.income_tax_rate with pytest.raises(ValueError): income_tax_rate("test") -def test_parameter_repr(tax_benefit_system): +def test_parameter_repr(tax_benefit_system) -> None: parameters = tax_benefit_system.parameters - tf = tempfile.NamedTemporaryFile(delete = False) - tf.write(repr(parameters).encode('utf-8')) + tf = tempfile.NamedTemporaryFile(delete=False) + tf.write(repr(parameters).encode("utf-8")) tf.close() - tf_parameters = load_parameter_file(file_path = tf.name) + tf_parameters = load_parameter_file(file_path=tf.name) assert repr(parameters) == repr(tf_parameters) -def test_parameters_metadata(tax_benefit_system): +def test_parameters_metadata(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits.basic_income - assert parameter.metadata['reference'] == 'https://law.gov.example/basic-income/amount' - assert parameter.metadata['unit'] == 'currency-EUR' - assert parameter.values_list[0].metadata['reference'] == 'https://law.gov.example/basic-income/amount/2015-12' - assert parameter.values_list[0].metadata['unit'] == 'currency-EUR' + assert ( + parameter.metadata["reference"] == "https://law.gov.example/basic-income/amount" + ) + assert parameter.metadata["unit"] == "currency-EUR" + assert ( + parameter.values_list[0].metadata["reference"] + == "https://law.gov.example/basic-income/amount/2015-12" + ) + assert parameter.values_list[0].metadata["unit"] == "currency-EUR" scale = tax_benefit_system.parameters.taxes.social_security_contribution - assert scale.metadata['threshold_unit'] == 'currency-EUR' - assert scale.metadata['rate_unit'] == '/1' + assert scale.metadata["threshold_unit"] == "currency-EUR" + assert scale.metadata["rate_unit"] == "/1" -def test_parameter_node_metadata(tax_benefit_system): +def test_parameter_node_metadata(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits - assert parameter.description == 'Social benefits' + assert parameter.description == "Social benefits" parameter_2 = tax_benefit_system.parameters.taxes.housing_tax - assert parameter_2.description == 'Housing tax' + assert parameter_2.description == "Housing tax" -def test_parameter_documentation(tax_benefit_system): +def test_parameter_documentation(tax_benefit_system) -> None: parameter = tax_benefit_system.parameters.benefits.housing_allowance - assert parameter.documentation == 'A fraction of the rent.\nFrom the 1st of Dec 2016, the housing allowance no longer exists.\n' + assert ( + parameter.documentation + == "A fraction of the rent.\nFrom the 1st of Dec 2016, the housing allowance no longer exists.\n" + ) -def test_get_descendants(tax_benefit_system): - all_parameters = {parameter.name for parameter in tax_benefit_system.parameters.get_descendants()} - assert all_parameters.issuperset({'taxes', 'taxes.housing_tax', 'taxes.housing_tax.minimal_amount'}) +def test_get_descendants(tax_benefit_system) -> None: + all_parameters = { + parameter.name for parameter in tax_benefit_system.parameters.get_descendants() + } + assert all_parameters.issuperset( + {"taxes", "taxes.housing_tax", "taxes.housing_tax.minimal_amount"}, + ) -def test_name(): +def test_name() -> None: parameter_data = { "description": "Parameter indexed by a numeric key", - "2010": { - "values": { - '2006-01-01': 0.0075 - } - } - } - parameter = ParameterNode('root', data = parameter_data) + "2010": {"values": {"2006-01-01": 0.0075}}, + } + parameter = ParameterNode("root", data=parameter_data) assert parameter.children["2010"].name == "root.2010" diff --git a/tests/core/test_periods.py b/tests/core/test_periods.py deleted file mode 100644 index 2c125d527c..0000000000 --- a/tests/core/test_periods.py +++ /dev/null @@ -1,203 +0,0 @@ -# -*- coding: utf-8 -*- - - -import pytest - -from openfisca_core.periods import Period, Instant, YEAR, MONTH, DAY, period - -first_jan = Instant((2014, 1, 1)) -first_march = Instant((2014, 3, 1)) - - -''' -Test Period -> String -''' - - -# Years - -def test_year(): - assert str(Period((YEAR, first_jan, 1))) == '2014' - - -def test_12_months_is_a_year(): - assert str(Period((MONTH, first_jan, 12))) == '2014' - - -def test_rolling_year(): - assert str(Period((MONTH, first_march, 12))) == 'year:2014-03' - assert str(Period((YEAR, first_march, 1))) == 'year:2014-03' - - -def test_several_years(): - assert str(Period((YEAR, first_jan, 3))) == 'year:2014:3' - assert str(Period((YEAR, first_march, 3))) == 'year:2014-03:3' - - -# Months - -def test_month(): - assert str(Period((MONTH, first_jan, 1))) == '2014-01' - - -def test_several_months(): - assert str(Period((MONTH, first_jan, 3))) == 'month:2014-01:3' - assert str(Period((MONTH, first_march, 3))) == 'month:2014-03:3' - - -# Days - -def test_day(): - assert str(Period((DAY, first_jan, 1))) == '2014-01-01' - - -def test_several_days(): - assert str(Period((DAY, first_jan, 3))) == 'day:2014-01-01:3' - assert str(Period((DAY, first_march, 3))) == 'day:2014-03-01:3' - - -''' -Test String -> Period -''' - - -# Years - -def test_parsing_year(): - assert period('2014') == Period((YEAR, first_jan, 1)) - - -def test_parsing_rolling_year(): - assert period('year:2014-03') == Period((YEAR, first_march, 1)) - - -def test_parsing_several_years(): - assert period('year:2014:2') == Period((YEAR, first_jan, 2)) - - -def test_wrong_syntax_several_years(): - with pytest.raises(ValueError): - period('2014:2') - - -# Months - -def test_parsing_month(): - assert period('2014-01') == Period((MONTH, first_jan, 1)) - - -def test_parsing_several_months(): - assert period('month:2014-03:3') == Period((MONTH, first_march, 3)) - - -def test_wrong_syntax_several_months(): - with pytest.raises(ValueError): - period('2014-3:3') - - -# Days - -def test_parsing_day(): - assert period('2014-01-01') == Period((DAY, first_jan, 1)) - - -def test_parsing_several_days(): - assert period('day:2014-03-01:3') == Period((DAY, first_march, 3)) - - -def test_wrong_syntax_several_days(): - with pytest.raises(ValueError): - period('2014-2-3:2') - - -def test_day_size_in_days(): - assert Period(('day', Instant((2014, 12, 31)), 1)).size_in_days == 1 - - -def test_3_day_size_in_days(): - assert Period(('day', Instant((2014, 12, 31)), 3)).size_in_days == 3 - - -def test_month_size_in_days(): - assert Period(('month', Instant((2014, 12, 1)), 1)).size_in_days == 31 - - -def test_leap_month_size_in_days(): - assert Period(('month', Instant((2012, 2, 3)), 1)).size_in_days == 29 - - -def test_3_month_size_in_days(): - assert Period(('month', Instant((2013, 1, 3)), 3)).size_in_days == 31 + 28 + 31 - - -def test_leap_3_month_size_in_days(): - assert Period(('month', Instant((2012, 1, 3)), 3)).size_in_days == 31 + 29 + 31 - - -def test_year_size_in_days(): - assert Period(('year', Instant((2014, 12, 1)), 1)).size_in_days == 365 - - -def test_leap_year_size_in_days(): - assert Period(('year', Instant((2012, 1, 1)), 1)).size_in_days == 366 - - -def test_2_years_size_in_days(): - assert Period(('year', Instant((2014, 1, 1)), 2)).size_in_days == 730 - -# Misc - - -def test_wrong_date(): - with pytest.raises(ValueError): - period("2006-31-03") - - -def test_ambiguous_period(): - with pytest.raises(ValueError): - period('month:2014') - - -def test_deprecated_signature(): - with pytest.raises(TypeError): - period(MONTH, 2014) - - -def test_wrong_argument(): - with pytest.raises(ValueError): - period({}) - - -def test_wrong_argument_1(): - with pytest.raises(ValueError): - period([]) - - -def test_none(): - with pytest.raises(ValueError): - period(None) - - -def test_empty_string(): - with pytest.raises(ValueError): - period('') - - -@pytest.mark.parametrize("test", [ - (period('year:2014:2'), YEAR, 2, period('2014'), period('2015')), - (period(2017), MONTH, 12, period('2017-01'), period('2017-12')), - (period('year:2014:2'), MONTH, 24, period('2014-01'), period('2015-12')), - (period('month:2014-03:3'), MONTH, 3, period('2014-03'), period('2014-05')), - (period(2017), DAY, 365, period('2017-01-01'), period('2017-12-31')), - (period('year:2014:2'), DAY, 730, period('2014-01-01'), period('2015-12-31')), - (period('month:2014-03:3'), DAY, 92, period('2014-03-01'), period('2014-05-31')), - ]) -def test_subperiods(test): - - def check_subperiods(period, unit, length, first, last): - subperiods = period.get_subperiods(unit) - assert len(subperiods) == length - assert subperiods[0] == first - assert subperiods[-1] == last - - check_subperiods(*test) diff --git a/tests/core/test_projectors.py b/tests/core/test_projectors.py new file mode 100644 index 0000000000..c62e49d3a7 --- /dev/null +++ b/tests/core/test_projectors.py @@ -0,0 +1,349 @@ +import numpy + +from openfisca_core.entities import build_entity +from openfisca_core.indexed_enums import Enum +from openfisca_core.periods import DateUnit +from openfisca_core.simulations.simulation_builder import SimulationBuilder +from openfisca_core.taxbenefitsystems import TaxBenefitSystem +from openfisca_core.variables import Variable + + +def test_shortcut_to_containing_entity_provided() -> None: + """Tests that, when an entity provides a containing entity, + the shortcut to that containing entity is provided. + """ + person_entity = build_entity( + key="person", + plural="people", + label="A person", + is_person=True, + ) + family_entity = build_entity( + key="family", + plural="families", + label="A family (all members in the same household)", + containing_entities=["household"], + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + household_entity = build_entity( + key="household", + plural="households", + label="A household, containing one or more families", + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + + entities = [person_entity, family_entity, household_entity] + + system = TaxBenefitSystem(entities) + simulation = SimulationBuilder().build_from_dict(system, {}) + assert simulation.populations["family"].household.entity.key == "household" + + +def test_shortcut_to_containing_entity_not_provided() -> None: + """Tests that, when an entity doesn't provide a containing + entity, the shortcut to that containing entity is not provided. + """ + person_entity = build_entity( + key="person", + plural="people", + label="A person", + is_person=True, + ) + family_entity = build_entity( + key="family", + plural="families", + label="A family (all members in the same household)", + containing_entities=[], + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + household_entity = build_entity( + key="household", + plural="households", + label="A household, containing one or more families", + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + + entities = [person_entity, family_entity, household_entity] + + system = TaxBenefitSystem(entities) + simulation = SimulationBuilder().build_from_dict(system, {}) + try: + simulation.populations["family"].household + raise AssertionError + except AttributeError: + pass + + +def test_enum_projects_downwards() -> None: + """Test that an Enum-type household-level variable projects + values onto its members correctly. + """ + person = build_entity( + key="person", + plural="people", + label="A person", + is_person=True, + ) + household = build_entity( + key="household", + plural="households", + label="A household", + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + + entities = [person, household] + + system = TaxBenefitSystem(entities) + + class enum(Enum): + FIRST_OPTION = "First option" + SECOND_OPTION = "Second option" + + class household_enum_variable(Variable): + value_type = Enum + possible_values = enum + default_value = enum.FIRST_OPTION + entity = household + definition_period = DateUnit.ETERNITY + + class projected_enum_variable(Variable): + value_type = Enum + possible_values = enum + default_value = enum.FIRST_OPTION + entity = person + definition_period = DateUnit.ETERNITY + + def formula(self, period): + return self.household("household_enum_variable", period) + + system.add_variables(household_enum_variable, projected_enum_variable) + + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": {"person1": {}, "person2": {}, "person3": {}}, + "households": { + "household1": { + "members": ["person1", "person2", "person3"], + "household_enum_variable": {"eternity": "SECOND_OPTION"}, + }, + }, + }, + ) + + assert ( + simulation.calculate("projected_enum_variable", "2021-01-01").decode_to_str() + == numpy.array(["SECOND_OPTION"] * 3) + ).all() + + +def test_enum_projects_upwards() -> None: + """Test that an Enum-type person-level variable projects + values onto its household (from the first person) correctly. + """ + person = build_entity( + key="person", + plural="people", + label="A person", + is_person=True, + ) + household = build_entity( + key="household", + plural="households", + label="A household", + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + + entities = [person, household] + + system = TaxBenefitSystem(entities) + + class enum(Enum): + FIRST_OPTION = "First option" + SECOND_OPTION = "Second option" + + class household_projected_variable(Variable): + value_type = Enum + possible_values = enum + default_value = enum.FIRST_OPTION + entity = household + definition_period = DateUnit.ETERNITY + + def formula(self, period): + return self.value_from_first_person( + self.members("person_enum_variable", period), + ) + + class person_enum_variable(Variable): + value_type = Enum + possible_values = enum + default_value = enum.FIRST_OPTION + entity = person + definition_period = DateUnit.ETERNITY + + system.add_variables(household_projected_variable, person_enum_variable) + + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": { + "person1": {"person_enum_variable": {"ETERNITY": "SECOND_OPTION"}}, + "person2": {}, + "person3": {}, + }, + "households": { + "household1": { + "members": ["person1", "person2", "person3"], + }, + }, + }, + ) + + assert ( + simulation.calculate( + "household_projected_variable", + "2021-01-01", + ).decode_to_str() + == numpy.array(["SECOND_OPTION"]) + ).all() + + +def test_enum_projects_between_containing_groups() -> None: + """Test that an Enum-type person-level variable projects + values onto its household (from the first person) correctly. + """ + person_entity = build_entity( + key="person", + plural="people", + label="A person", + is_person=True, + ) + family_entity = build_entity( + key="family", + plural="families", + label="A family (all members in the same household)", + containing_entities=["household"], + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + household_entity = build_entity( + key="household", + plural="households", + label="A household, containing one or more families", + roles=[ + { + "key": "member", + "plural": "members", + "label": "Member", + }, + ], + ) + + entities = [person_entity, family_entity, household_entity] + + system = TaxBenefitSystem(entities) + + class enum(Enum): + FIRST_OPTION = "First option" + SECOND_OPTION = "Second option" + + class household_level_variable(Variable): + value_type = Enum + possible_values = enum + default_value = enum.FIRST_OPTION + entity = household_entity + definition_period = DateUnit.ETERNITY + + class projected_family_level_variable(Variable): + value_type = Enum + possible_values = enum + default_value = enum.FIRST_OPTION + entity = family_entity + definition_period = DateUnit.ETERNITY + + def formula(self, period): + return self.household("household_level_variable", period) + + class decoded_projected_family_level_variable(Variable): + value_type = str + entity = family_entity + definition_period = DateUnit.ETERNITY + + def formula(self, period): + return self.household("household_level_variable", period).decode_to_str() + + system.add_variables( + household_level_variable, + projected_family_level_variable, + decoded_projected_family_level_variable, + ) + + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": {"person1": {}, "person2": {}, "person3": {}}, + "families": { + "family1": {"members": ["person1", "person2"]}, + "family2": {"members": ["person3"]}, + }, + "households": { + "household1": { + "members": ["person1", "person2", "person3"], + "household_level_variable": {"eternity": "SECOND_OPTION"}, + }, + }, + }, + ) + + assert ( + simulation.calculate( + "projected_family_level_variable", + "2021-01-01", + ).decode_to_str() + == numpy.array(["SECOND_OPTION"]) + ).all() + assert ( + simulation.calculate("decoded_projected_family_level_variable", "2021-01-01") + == numpy.array(["SECOND_OPTION"]) + ).all() diff --git a/tests/core/test_reforms.py b/tests/core/test_reforms.py index 8735cee18f..1f31bcde2a 100644 --- a/tests/core/test_reforms.py +++ b/tests/core/test_reforms.py @@ -2,12 +2,14 @@ import pytest -from openfisca_core import periods -from openfisca_core.periods import Instant -from openfisca_core.tools import assert_near -from openfisca_core.parameters import ValuesHistory, ParameterNode from openfisca_country_template.entities import Household, Person -from openfisca_core.model_api import * # noqa analysis:ignore + +from openfisca_core import holders, periods, simulations +from openfisca_core.parameters import ParameterNode, ValuesHistory +from openfisca_core.periods import DateUnit, Instant +from openfisca_core.reforms import Reform +from openfisca_core.tools import assert_near +from openfisca_core.variables import Variable class goes_to_school(Variable): @@ -15,333 +17,463 @@ class goes_to_school(Variable): default_value = True entity = Person label = "The person goes to school (only relevant for children)" - definition_period = MONTH + definition_period = DateUnit.MONTH class WithBasicIncomeNeutralized(Reform): - def apply(self): - self.neutralize_variable('basic_income') + def apply(self) -> None: + self.neutralize_variable("basic_income") -@pytest.fixture(scope = "module", autouse = True) -def add_variables_to_tax_benefit_system(tax_benefit_system): +@pytest.fixture(scope="module", autouse=True) +def add_variables_to_tax_benefit_system(tax_benefit_system) -> None: tax_benefit_system.add_variables(goes_to_school) -def test_formula_neutralization(make_simulation, tax_benefit_system): +def test_formula_neutralization(make_simulation, tax_benefit_system) -> None: reform = WithBasicIncomeNeutralized(tax_benefit_system) - period = '2017-01' + period = "2017-01" simulation = make_simulation(reform.base_tax_benefit_system, {}, period) simulation.debug = True - basic_income = simulation.calculate('basic_income', period = period) + basic_income = simulation.calculate("basic_income", period=period) assert_near(basic_income, 600) - disposable_income = simulation.calculate('disposable_income', period = period) + disposable_income = simulation.calculate("disposable_income", period=period) assert disposable_income > 0 reform_simulation = make_simulation(reform, {}, period) reform_simulation.debug = True - basic_income_reform = reform_simulation.calculate('basic_income', period = '2013-01') - assert_near(basic_income_reform, 0, absolute_error_margin = 0) - disposable_income_reform = reform_simulation.calculate('disposable_income', period = period) + basic_income_reform = reform_simulation.calculate("basic_income", period="2013-01") + assert_near(basic_income_reform, 0, absolute_error_margin=0) + disposable_income_reform = reform_simulation.calculate( + "disposable_income", + period=period, + ) assert_near(disposable_income_reform, 0) -def test_neutralization_variable_with_default_value(make_simulation, tax_benefit_system): +def test_neutralization_variable_with_default_value( + make_simulation, + tax_benefit_system, +) -> None: class test_goes_to_school_neutralization(Reform): - def apply(self): - self.neutralize_variable('goes_to_school') + def apply(self) -> None: + self.neutralize_variable("goes_to_school") reform = test_goes_to_school_neutralization(tax_benefit_system) period = "2017-01" simulation = make_simulation(reform.base_tax_benefit_system, {}, period) - goes_to_school = simulation.calculate('goes_to_school', period) - assert_near(goes_to_school, [True], absolute_error_margin = 0) + goes_to_school = simulation.calculate("goes_to_school", period) + assert_near(goes_to_school, [True], absolute_error_margin=0) -def test_neutralization_optimization(make_simulation, tax_benefit_system): +def test_neutralization_optimization(make_simulation, tax_benefit_system) -> None: reform = WithBasicIncomeNeutralized(tax_benefit_system) - period = '2017-01' + period = "2017-01" simulation = make_simulation(reform, {}, period) simulation.debug = True - simulation.calculate('basic_income', period = '2013-01') - simulation.calculate_add('basic_income', period = '2013') + simulation.calculate("basic_income", period="2013-01") + simulation.calculate_add("basic_income", period="2013") # As basic_income is neutralized, it should not be cached - basic_income_holder = simulation.persons.get_holder('basic_income') + basic_income_holder = simulation.persons.get_holder("basic_income") assert basic_income_holder.get_known_periods() == [] -def test_input_variable_neutralization(make_simulation, tax_benefit_system): - +def test_input_variable_neutralization(make_simulation, tax_benefit_system) -> None: class test_salary_neutralization(Reform): - def apply(self): - self.neutralize_variable('salary') + def apply(self) -> None: + self.neutralize_variable("salary") reform = test_salary_neutralization(tax_benefit_system) - period = '2017-01' + period = "2017-01" reform = test_salary_neutralization(tax_benefit_system) with warnings.catch_warnings(record=True) as raised_warnings: - reform_simulation = make_simulation(reform, {'salary': [1200, 1000]}, period) - assert 'You cannot set a value for the variable' in raised_warnings[0].message.args[0] - salary = reform_simulation.calculate('salary', period) - assert_near(salary, [0, 0],) - disposable_income_reform = reform_simulation.calculate('disposable_income', period = period) + reform_simulation = make_simulation(reform, {"salary": [1200, 1000]}, period) + assert ( + "You cannot set a value for the variable" + in raised_warnings[0].message.args[0] + ) + salary = reform_simulation.calculate("salary", period) + assert_near( + salary, + [0, 0], + ) + disposable_income_reform = reform_simulation.calculate( + "disposable_income", + period=period, + ) assert_near(disposable_income_reform, [600, 600]) -def test_permanent_variable_neutralization(make_simulation, tax_benefit_system): - +def test_permanent_variable_neutralization(make_simulation, tax_benefit_system) -> None: class test_date_naissance_neutralization(Reform): - def apply(self): - self.neutralize_variable('birth') + def apply(self) -> None: + self.neutralize_variable("birth") reform = test_date_naissance_neutralization(tax_benefit_system) - period = '2017-01' - simulation = make_simulation(reform.base_tax_benefit_system, {'birth': '1980-01-01'}, period) + period = "2017-01" + simulation = make_simulation( + reform.base_tax_benefit_system, + {"birth": "1980-01-01"}, + period, + ) with warnings.catch_warnings(record=True) as raised_warnings: - reform_simulation = make_simulation(reform, {'birth': '1980-01-01'}, period) - assert 'You cannot set a value for the variable' in raised_warnings[0].message.args[0] - assert str(simulation.calculate('birth', None)[0]) == '1980-01-01' - assert str(reform_simulation.calculate('birth', None)[0]) == '1970-01-01' - - -def test_update_items(): - def check_update_items(description, value_history, start_instant, stop_instant, value, expected_items): - value_history.update(period=None, start=start_instant, stop=stop_instant, value=value) + reform_simulation = make_simulation(reform, {"birth": "1980-01-01"}, period) + assert ( + "You cannot set a value for the variable" + in raised_warnings[0].message.args[0] + ) + assert str(simulation.calculate("birth", None)[0]) == "1980-01-01" + assert str(reform_simulation.calculate("birth", None)[0]) == "1970-01-01" + + +def test_update_items() -> None: + def check_update_items( + description, + value_history, + start_instant, + stop_instant, + value, + expected_items, + ) -> None: + value_history.update( + period=None, + start=start_instant, + stop=stop_instant, + value=value, + ) assert value_history == expected_items check_update_items( - 'Replace an item by a new item', - ValuesHistory('dummy_name', {"2013-01-01": {'value': 0.0}, "2014-01-01": {'value': None}}), + "Replace an item by a new item", + ValuesHistory( + "dummy_name", + {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, + ), periods.period(2013).start, periods.period(2013).stop, 1.0, - ValuesHistory('dummy_name', {"2013-01-01": {'value': 1.0}, "2014-01-01": {'value': None}}), - ) + ValuesHistory( + "dummy_name", + {"2013-01-01": {"value": 1.0}, "2014-01-01": {"value": None}}, + ), + ) check_update_items( - 'Replace an item by a new item in a list of items, the last being open', - ValuesHistory('dummy_name', {"2014-01-01": {'value': 9.53}, "2015-01-01": {'value': 9.61}, "2016-01-01": {'value': 9.67}}), + "Replace an item by a new item in a list of items, the last being open", + ValuesHistory( + "dummy_name", + { + "2014-01-01": {"value": 9.53}, + "2015-01-01": {"value": 9.61}, + "2016-01-01": {"value": 9.67}, + }, + ), periods.period(2015).start, periods.period(2015).stop, 1.0, - ValuesHistory('dummy_name', {"2014-01-01": {'value': 9.53}, "2015-01-01": {'value': 1.0}, "2016-01-01": {'value': 9.67}}), - ) + ValuesHistory( + "dummy_name", + { + "2014-01-01": {"value": 9.53}, + "2015-01-01": {"value": 1.0}, + "2016-01-01": {"value": 9.67}, + }, + ), + ) check_update_items( - 'Open the stop instant to the future', - ValuesHistory('dummy_name', {"2013-01-01": {'value': 0.0}, "2014-01-01": {'value': None}}), + "Open the stop instant to the future", + ValuesHistory( + "dummy_name", + {"2013-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, + ), periods.period(2013).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2013-01-01": {'value': 1.0}}), - ) + ValuesHistory("dummy_name", {"2013-01-01": {"value": 1.0}}), + ) check_update_items( - 'Insert a new item in the middle of an existing item', - ValuesHistory('dummy_name', {"2010-01-01": {'value': 0.0}, "2014-01-01": {'value': None}}), + "Insert a new item in the middle of an existing item", + ValuesHistory( + "dummy_name", + {"2010-01-01": {"value": 0.0}, "2014-01-01": {"value": None}}, + ), periods.period(2011).start, periods.period(2011).stop, 1.0, - ValuesHistory('dummy_name', {"2010-01-01": {'value': 0.0}, "2011-01-01": {'value': 1.0}, "2012-01-01": {'value': 0.0}, "2014-01-01": {'value': None}}), - ) + ValuesHistory( + "dummy_name", + { + "2010-01-01": {"value": 0.0}, + "2011-01-01": {"value": 1.0}, + "2012-01-01": {"value": 0.0}, + "2014-01-01": {"value": None}, + }, + ), + ) check_update_items( - 'Insert a new open item coming after the last open item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new open item coming after the last open item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2015).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}, "2015-01-01": {'value': 1.0}}), - ) + ValuesHistory( + "dummy_name", + { + "2006-01-01": {"value": 0.055}, + "2014-01-01": {"value": 0.14}, + "2015-01-01": {"value": 1.0}, + }, + ), + ) check_update_items( - 'Insert a new item starting at the same date than the last open item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new item starting at the same date than the last open item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2014).start, periods.period(2014).stop, 1.0, - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 1.0}, "2015-01-01": {'value': 0.14}}), - ) + ValuesHistory( + "dummy_name", + { + "2006-01-01": {"value": 0.055}, + "2014-01-01": {"value": 1.0}, + "2015-01-01": {"value": 0.14}, + }, + ), + ) check_update_items( - 'Insert a new open item starting at the same date than the last open item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new open item starting at the same date than the last open item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2014).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 1.0}}), - ) + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 1.0}}, + ), + ) check_update_items( - 'Insert a new item coming before the first item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new item coming before the first item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2005).start, periods.period(2005).stop, 1.0, - ValuesHistory('dummy_name', {"2005-01-01": {'value': 1.0}, "2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), - ) + ValuesHistory( + "dummy_name", + { + "2005-01-01": {"value": 1.0}, + "2006-01-01": {"value": 0.055}, + "2014-01-01": {"value": 0.14}, + }, + ), + ) check_update_items( - 'Insert a new item coming before the first item with a hole', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new item coming before the first item with a hole", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2003).start, periods.period(2003).stop, 1.0, - ValuesHistory('dummy_name', {"2003-01-01": {'value': 1.0}, "2004-01-01": {'value': None}, "2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), - ) + ValuesHistory( + "dummy_name", + { + "2003-01-01": {"value": 1.0}, + "2004-01-01": {"value": None}, + "2006-01-01": {"value": 0.055}, + "2014-01-01": {"value": 0.14}, + }, + ), + ) check_update_items( - 'Insert a new open item starting before the start date of the first item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new open item starting before the start date of the first item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2005).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2005-01-01": {'value': 1.0}}), - ) + ValuesHistory("dummy_name", {"2005-01-01": {"value": 1.0}}), + ) check_update_items( - 'Insert a new open item starting at the same date than the first item', - ValuesHistory('dummy_name', {"2006-01-01": {'value': 0.055}, "2014-01-01": {'value': 0.14}}), + "Insert a new open item starting at the same date than the first item", + ValuesHistory( + "dummy_name", + {"2006-01-01": {"value": 0.055}, "2014-01-01": {"value": 0.14}}, + ), periods.period(2006).start, None, # stop instant 1.0, - ValuesHistory('dummy_name', {"2006-01-01": {'value': 1.0}}), - ) + ValuesHistory("dummy_name", {"2006-01-01": {"value": 1.0}}), + ) -def test_add_variable(make_simulation, tax_benefit_system): +def test_add_variable(make_simulation, tax_benefit_system) -> None: class new_variable(Variable): value_type = int label = "Nouvelle variable introduite par la réforme" entity = Household - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula(household, period): - return household.empty_array() + 10 + def formula(self, period): + return self.empty_array() + 10 class test_add_variable(Reform): - - def apply(self): + def apply(self) -> None: self.add_variable(new_variable) reform = test_add_variable(tax_benefit_system) - assert tax_benefit_system.get_variable('new_variable') is None + assert tax_benefit_system.get_variable("new_variable") is None reform_simulation = make_simulation(reform, {}, 2013) reform_simulation.debug = True - new_variable1 = reform_simulation.calculate('new_variable', period = '2013-01') - assert_near(new_variable1, 10, absolute_error_margin = 0) + new_variable1 = reform_simulation.calculate("new_variable", period="2013-01") + assert_near(new_variable1, 10, absolute_error_margin=0) -def test_add_dated_variable(make_simulation, tax_benefit_system): +def test_add_dated_variable(make_simulation, tax_benefit_system) -> None: class new_dated_variable(Variable): value_type = int label = "Nouvelle variable introduite par la réforme" entity = Household - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula_2010_01_01(household, period): - return household.empty_array() + 10 + def formula_2010_01_01(self, period): + return self.empty_array() + 10 - def formula_2011_01_01(household, period): - return household.empty_array() + 15 + def formula_2011_01_01(self, period): + return self.empty_array() + 15 class test_add_variable(Reform): - def apply(self): + def apply(self) -> None: self.add_variable(new_dated_variable) reform = test_add_variable(tax_benefit_system) - reform_simulation = make_simulation(reform, {}, '2013-01') + reform_simulation = make_simulation(reform, {}, "2013-01") reform_simulation.debug = True - new_dated_variable1 = reform_simulation.calculate('new_dated_variable', period = '2013-01') - assert_near(new_dated_variable1, 15, absolute_error_margin = 0) - + new_dated_variable1 = reform_simulation.calculate( + "new_dated_variable", + period="2013-01", + ) + assert_near(new_dated_variable1, 15, absolute_error_margin=0) -def test_update_variable(make_simulation, tax_benefit_system): +def test_update_variable(make_simulation, tax_benefit_system) -> None: class disposable_income(Variable): - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula_2018(household, period): - return household.empty_array() + 10 + def formula_2018(self, period): + return self.empty_array() + 10 class test_update_variable(Reform): - def apply(self): + def apply(self) -> None: self.update_variable(disposable_income) reform = test_update_variable(tax_benefit_system) - disposable_income_reform = reform.get_variable('disposable_income') - disposable_income_baseline = tax_benefit_system.get_variable('disposable_income') + disposable_income_reform = reform.get_variable("disposable_income") + disposable_income_baseline = tax_benefit_system.get_variable("disposable_income") assert disposable_income_reform is not None - assert disposable_income_reform.entity.plural == disposable_income_baseline.entity.plural + assert ( + disposable_income_reform.entity.plural + == disposable_income_baseline.entity.plural + ) assert disposable_income_reform.name == disposable_income_baseline.name assert disposable_income_reform.label == disposable_income_baseline.label reform_simulation = make_simulation(reform, {}, 2018) - disposable_income1 = reform_simulation.calculate('disposable_income', period = '2018-01') - assert_near(disposable_income1, 10, absolute_error_margin = 0) - - disposable_income2 = reform_simulation.calculate('disposable_income', period = '2017-01') + disposable_income1 = reform_simulation.calculate( + "disposable_income", + period="2018-01", + ) + assert_near(disposable_income1, 10, absolute_error_margin=0) + + disposable_income2 = reform_simulation.calculate( + "disposable_income", + period="2017-01", + ) # Before 2018, the former formula is used - assert(disposable_income2 > 100) - + assert disposable_income2 > 100 -def test_replace_variable(tax_benefit_system): +def test_replace_variable(tax_benefit_system) -> None: class disposable_income(Variable): - definition_period = MONTH + definition_period = DateUnit.MONTH entity = Person label = "Disposable income" value_type = float - def formula_2018(household, period): - return household.empty_array() + 10 + def formula_2018(self, period): + return self.empty_array() + 10 class test_update_variable(Reform): - def apply(self): + def apply(self) -> None: self.replace_variable(disposable_income) reform = test_update_variable(tax_benefit_system) - disposable_income_reform = reform.get_variable('disposable_income') - assert disposable_income_reform.get_formula('2017') is None + disposable_income_reform = reform.get_variable("disposable_income") + assert disposable_income_reform.get_formula("2017") is None -def test_wrong_reform(tax_benefit_system): +def test_wrong_reform(tax_benefit_system) -> None: class wrong_reform(Reform): # A Reform must implement an `apply` method pass - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 wrong_reform(tax_benefit_system) -def test_modify_parameters(tax_benefit_system): - +def test_modify_parameters(tax_benefit_system) -> None: def modify_parameters(reference_parameters): reform_parameters_subtree = ParameterNode( - 'new_node', - data = { - 'new_param': { - 'values': {"2000-01-01": {'value': True}, "2015-01-01": {'value': None}} + "new_node", + data={ + "new_param": { + "values": { + "2000-01-01": {"value": True}, + "2015-01-01": {"value": None}, }, }, - ) - reference_parameters.children['new_node'] = reform_parameters_subtree + }, + ) + reference_parameters.children["new_node"] = reform_parameters_subtree return reference_parameters class test_modify_parameters(Reform): - def apply(self): - self.modify_parameters(modifier_function = modify_parameters) + def apply(self) -> None: + self.modify_parameters(modifier_function=modify_parameters) reform = test_modify_parameters(tax_benefit_system) - parameters_new_node = reform.parameters.children['new_node'] + parameters_new_node = reform.parameters.children["new_node"] assert parameters_new_node is not None instant = Instant((2013, 1, 1)) @@ -349,15 +481,14 @@ def apply(self): assert parameters_at_instant.new_node.new_param is True -def test_attributes_conservation(tax_benefit_system): - +def test_attributes_conservation(tax_benefit_system) -> None: class some_variable(Variable): value_type = int entity = Person label = "Variable with many attributes" - definition_period = MONTH - set_input = set_input_divide_by_period - calculate_output = calculate_output_add + definition_period = DateUnit.MONTH + set_input = holders.set_input_divide_by_period + calculate_output = simulations.calculate_output_add tax_benefit_system.add_variable(some_variable) @@ -365,12 +496,12 @@ class reform(Reform): class some_variable(Variable): default_value = 10 - def apply(self): + def apply(self) -> None: self.update_variable(some_variable) reformed_tbs = reform(tax_benefit_system) - reform_variable = reformed_tbs.get_variable('some_variable') - baseline_variable = tax_benefit_system.get_variable('some_variable') + reform_variable = reformed_tbs.get_variable("some_variable") + baseline_variable = tax_benefit_system.get_variable("some_variable") assert reform_variable.value_type == baseline_variable.value_type assert reform_variable.entity == baseline_variable.entity assert reform_variable.label == baseline_variable.label @@ -379,18 +510,17 @@ def apply(self): assert reform_variable.calculate_output == baseline_variable.calculate_output -def test_formulas_removal(tax_benefit_system): +def test_formulas_removal(tax_benefit_system) -> None: class reform(Reform): - def apply(self): - + def apply(self) -> None: class basic_income(Variable): pass self.update_variable(basic_income) - self.variables['basic_income'].formulas.clear() + self.variables["basic_income"].formulas.clear() reformed_tbs = reform(tax_benefit_system) - reform_variable = reformed_tbs.get_variable('basic_income') - baseline_variable = tax_benefit_system.get_variable('basic_income') + reform_variable = reformed_tbs.get_variable("basic_income") + baseline_variable = tax_benefit_system.get_variable("basic_income") assert len(reform_variable.formulas) == 0 assert len(baseline_variable.formulas) > 0 diff --git a/tests/core/test_simulation_builder.py b/tests/core/test_simulation_builder.py index 4fefb61fda..b905b29b84 100644 --- a/tests/core/test_simulation_builder.py +++ b/tests/core/test_simulation_builder.py @@ -1,13 +1,15 @@ +from collections.abc import Iterable + import datetime -from typing import Iterable import pytest from openfisca_country_template import entities, situation_examples -from openfisca_core import periods, tools +from openfisca_core import tools from openfisca_core.errors import SituationParsingError from openfisca_core.indexed_enums import Enum +from openfisca_core.periods import DateUnit from openfisca_core.populations import Population from openfisca_core.simulations import Simulation, SimulationBuilder from openfisca_core.tools import test_runner @@ -16,13 +18,12 @@ @pytest.fixture def int_variable(persons): - class intvar(Variable): - definition_period = periods.ETERNITY + definition_period = DateUnit.ETERNITY value_type = int entity = persons - def __init__(self): + def __init__(self) -> None: super().__init__() return intvar() @@ -30,13 +31,12 @@ def __init__(self): @pytest.fixture def date_variable(persons): - class datevar(Variable): - definition_period = periods.ETERNITY + definition_period = DateUnit.ETERNITY value_type = datetime.date entity = persons - def __init__(self): + def __init__(self) -> None: super().__init__() return datevar() @@ -44,213 +44,336 @@ def __init__(self): @pytest.fixture def enum_variable(): - class TestEnum(Variable): - definition_period = periods.ETERNITY + definition_period = DateUnit.ETERNITY value_type = Enum - dtype = 'O' - default_value = '0' + dtype = "O" + default_value = "0" is_neutralized = False set_input = None - possible_values = Enum('foo', 'bar') + possible_values = Enum("foo", "bar") name = "enum" - def __init__(self): + def __init__(self) -> None: pass return TestEnum() -def test_build_default_simulation(tax_benefit_system): - one_person_simulation = SimulationBuilder().build_default_simulation(tax_benefit_system, 1) +def test_build_default_simulation(tax_benefit_system) -> None: + one_person_simulation = SimulationBuilder().build_default_simulation( + tax_benefit_system, + 1, + ) assert one_person_simulation.persons.count == 1 assert one_person_simulation.household.count == 1 assert one_person_simulation.household.members_entity_id == [0] - assert one_person_simulation.household.members_role == entities.Household.FIRST_PARENT + assert ( + one_person_simulation.household.members_role == entities.Household.FIRST_PARENT + ) - several_persons_simulation = SimulationBuilder().build_default_simulation(tax_benefit_system, 4) + several_persons_simulation = SimulationBuilder().build_default_simulation( + tax_benefit_system, + 4, + ) assert several_persons_simulation.persons.count == 4 assert several_persons_simulation.household.count == 4 - assert (several_persons_simulation.household.members_entity_id == [0, 1, 2, 3]).all() - assert (several_persons_simulation.household.members_role == entities.Household.FIRST_PARENT).all() + assert ( + several_persons_simulation.household.members_entity_id == [0, 1, 2, 3] + ).all() + assert ( + several_persons_simulation.household.members_role + == entities.Household.FIRST_PARENT + ).all() -def test_explicit_singular_entities(tax_benefit_system): +def test_explicit_singular_entities(tax_benefit_system) -> None: assert SimulationBuilder().explicit_singular_entities( tax_benefit_system, - {'persons': {'Javier': {}}, 'household': {'parents': ['Javier']}} - ) == {'persons': {'Javier': {}}, 'households': {'household': {'parents': ['Javier']}}} + {"persons": {"Javier": {}}, "household": {"parents": ["Javier"]}}, + ) == { + "persons": {"Javier": {}}, + "households": {"household": {"parents": ["Javier"]}}, + } -def test_add_person_entity(persons): - persons_json = {'Alicia': {'salary': {}}, 'Javier': {}} +def test_add_person_entity(persons) -> None: + persons_json = {"Alicia": {"salary": {}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) - assert simulation_builder.get_count('persons') == 2 - assert simulation_builder.get_ids('persons') == ['Alicia', 'Javier'] + assert simulation_builder.get_count("persons") == 2 + assert simulation_builder.get_ids("persons") == ["Alicia", "Javier"] -def test_numeric_ids(persons): - persons_json = {1: {'salary': {}}, 2: {}} +def test_numeric_ids(persons) -> None: + persons_json = {1: {"salary": {}}, 2: {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) - assert simulation_builder.get_count('persons') == 2 - assert simulation_builder.get_ids('persons') == ['1', '2'] + assert simulation_builder.get_count("persons") == 2 + assert simulation_builder.get_ids("persons") == ["1", "2"] -def test_add_person_entity_with_values(persons): - persons_json = {'Alicia': {'salary': {'2018-11': 3000}}, 'Javier': {}} +def test_add_person_entity_with_values(persons) -> None: + persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) - tools.assert_near(simulation_builder.get_input('salary', '2018-11'), [3000, 0]) + tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_person_values_with_default_period(persons): - persons_json = {'Alicia': {'salary': 3000}, 'Javier': {}} +def test_add_person_values_with_default_period(persons) -> None: + persons_json = {"Alicia": {"salary": 3000}, "Javier": {}} simulation_builder = SimulationBuilder() - simulation_builder.set_default_period('2018-11') + simulation_builder.set_default_period("2018-11") simulation_builder.add_person_entity(persons, persons_json) - tools.assert_near(simulation_builder.get_input('salary', '2018-11'), [3000, 0]) + tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_person_values_with_default_period_old_syntax(persons): - persons_json = {'Alicia': {'salary': 3000}, 'Javier': {}} +def test_add_person_values_with_default_period_old_syntax(persons) -> None: + persons_json = {"Alicia": {"salary": 3000}, "Javier": {}} simulation_builder = SimulationBuilder() - simulation_builder.set_default_period('month:2018-11') + simulation_builder.set_default_period("month:2018-11") simulation_builder.add_person_entity(persons, persons_json) - tools.assert_near(simulation_builder.get_input('salary', '2018-11'), [3000, 0]) + tools.assert_near(simulation_builder.get_input("salary", "2018-11"), [3000, 0]) -def test_add_group_entity(households): +def test_add_group_entity(households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Sarah', 'Tom'], households, { - 'Household_1': {'parents': ['Alicia', 'Javier']}, - 'Household_2': {'parents': ['Tom'], 'children': ['Sarah']}, - }) - assert simulation_builder.get_count('households') == 2 - assert simulation_builder.get_ids('households') == ['Household_1', 'Household_2'] - assert simulation_builder.get_memberships('households') == [0, 0, 1, 1] - assert [role.key for role in simulation_builder.get_roles('households')] == ['parent', 'parent', 'child', 'parent'] - - -def test_add_group_entity_loose_syntax(households): + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Sarah", "Tom"], + households, + { + "Household_1": {"parents": ["Alicia", "Javier"]}, + "Household_2": {"parents": ["Tom"], "children": ["Sarah"]}, + }, + ) + assert simulation_builder.get_count("households") == 2 + assert simulation_builder.get_ids("households") == ["Household_1", "Household_2"] + assert simulation_builder.get_memberships("households") == [0, 0, 1, 1] + assert [role.key for role in simulation_builder.get_roles("households")] == [ + "parent", + "parent", + "child", + "parent", + ] + + +def test_add_group_entity_loose_syntax(households) -> None: simulation_builder = SimulationBuilder() - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Sarah', '1'], households, { - 'Household_1': {'parents': ['Alicia', 'Javier']}, - 'Household_2': {'parents': 1, 'children': 'Sarah'}, - }) - assert simulation_builder.get_count('households') == 2 - assert simulation_builder.get_ids('households') == ['Household_1', 'Household_2'] - assert simulation_builder.get_memberships('households') == [0, 0, 1, 1] - assert [role.key for role in simulation_builder.get_roles('households')] == ['parent', 'parent', 'child', 'parent'] - - -def test_add_variable_value(persons): - salary = persons.variables.get("salary") + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Sarah", "1"], + households, + { + "Household_1": {"parents": ["Alicia", "Javier"]}, + "Household_2": {"parents": 1, "children": "Sarah"}, + }, + ) + assert simulation_builder.get_count("households") == 2 + assert simulation_builder.get_ids("households") == ["Household_1", "Household_2"] + assert simulation_builder.get_memberships("households") == [0, 0, 1, 1] + assert [role.key for role in simulation_builder.get_roles("households")] == [ + "parent", + "parent", + "child", + "parent", + ] + + +def test_add_variable_value(persons) -> None: + salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 - simulation_builder.add_variable_value(persons, salary, instance_index, 'Alicia', '2018-11', 3000) - input_array = simulation_builder.get_input('salary', '2018-11') + simulation_builder.entity_counts["persons"] = 1 + simulation_builder.add_variable_value( + persons, + salary, + instance_index, + "Alicia", + "2018-11", + 3000, + ) + input_array = simulation_builder.get_input("salary", "2018-11") assert input_array[instance_index] == pytest.approx(3000) -def test_add_variable_value_as_expression(persons): - salary = persons.variables.get("salary") +def test_add_variable_value_as_expression(persons) -> None: + salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 - simulation_builder.add_variable_value(persons, salary, instance_index, 'Alicia', '2018-11', '3 * 1000') - input_array = simulation_builder.get_input('salary', '2018-11') + simulation_builder.entity_counts["persons"] = 1 + simulation_builder.add_variable_value( + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "3 * 1000", + ) + input_array = simulation_builder.get_input("salary", "2018-11") assert input_array[instance_index] == pytest.approx(3000) -def test_fail_on_wrong_data(persons): - salary = persons.variables.get("salary") +def test_fail_on_wrong_data(persons) -> None: + salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: - simulation_builder.add_variable_value(persons, salary, instance_index, 'Alicia', '2018-11', 'alicia') - assert excinfo.value.error == {'persons': {'Alicia': {'salary': {'2018-11': "Can't deal with value: expected type number, received 'alicia'."}}}} - - -def test_fail_on_ill_formed_expression(persons): - salary = persons.variables.get("salary") + simulation_builder.add_variable_value( + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "alicia", + ) + assert excinfo.value.error == { + "persons": { + "Alicia": { + "salary": { + "2018-11": "Can't deal with value: expected type number, received 'alicia'.", + }, + }, + }, + } + + +def test_fail_on_ill_formed_expression(persons) -> None: + salary = persons.get_variable("salary") instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: - simulation_builder.add_variable_value(persons, salary, instance_index, 'Alicia', '2018-11', '2 * / 1000') - assert excinfo.value.error == {'persons': {'Alicia': {'salary': {'2018-11': "I couldn't understand '2 * / 1000' as a value for 'salary'"}}}} - - -def test_fail_on_integer_overflow(persons, int_variable): + simulation_builder.add_variable_value( + persons, + salary, + instance_index, + "Alicia", + "2018-11", + "2 * / 1000", + ) + assert excinfo.value.error == { + "persons": { + "Alicia": { + "salary": { + "2018-11": "I couldn't understand '2 * / 1000' as a value for 'salary'", + }, + }, + }, + } + + +def test_fail_on_integer_overflow(persons, int_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: - simulation_builder.add_variable_value(persons, int_variable, instance_index, 'Alicia', '2018-11', 9223372036854775808) - assert excinfo.value.error == {'persons': {'Alicia': {'intvar': {'2018-11': "Can't deal with value: '9223372036854775808', it's too large for type 'integer'."}}}} - - -def test_fail_on_date_parsing(persons, date_variable): + simulation_builder.add_variable_value( + persons, + int_variable, + instance_index, + "Alicia", + "2018-11", + 9223372036854775808, + ) + assert excinfo.value.error == { + "persons": { + "Alicia": { + "intvar": { + "2018-11": "Can't deal with value: '9223372036854775808', it's too large for type 'integer'.", + }, + }, + }, + } + + +def test_fail_on_date_parsing(persons, date_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError) as excinfo: - simulation_builder.add_variable_value(persons, date_variable, instance_index, 'Alicia', '2018-11', '2019-02-30') - assert excinfo.value.error == {'persons': {'Alicia': {'datevar': {'2018-11': "Can't deal with date: '2019-02-30'."}}}} + simulation_builder.add_variable_value( + persons, + date_variable, + instance_index, + "Alicia", + "2018-11", + "2019-02-30", + ) + assert excinfo.value.error == { + "persons": { + "Alicia": {"datevar": {"2018-11": "Can't deal with date: '2019-02-30'."}}, + }, + } -def test_add_unknown_enum_variable_value(persons, enum_variable): +def test_add_unknown_enum_variable_value(persons, enum_variable) -> None: instance_index = 0 simulation_builder = SimulationBuilder() - simulation_builder.entity_counts['persons'] = 1 + simulation_builder.entity_counts["persons"] = 1 with pytest.raises(SituationParsingError): - simulation_builder.add_variable_value(persons, enum_variable, instance_index, 'Alicia', '2018-11', 'baz') + simulation_builder.add_variable_value( + persons, + enum_variable, + instance_index, + "Alicia", + "2018-11", + "baz", + ) -def test_finalize_person_entity(persons): - persons_json = {'Alicia': {'salary': {'2018-11': 3000}}, 'Javier': {}} +def test_finalize_person_entity(persons) -> None: + persons_json = {"Alicia": {"salary": {"2018-11": 3000}}, "Javier": {}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) population = Population(persons) simulation_builder.finalize_variables_init(population) - tools.assert_near(population.get_holder('salary').get_array('2018-11'), [3000, 0]) + tools.assert_near(population.get_holder("salary").get_array("2018-11"), [3000, 0]) assert population.count == 2 - assert population.ids == ['Alicia', 'Javier'] + assert population.ids == ["Alicia", "Javier"] -def test_canonicalize_period_keys(persons): - persons_json = {'Alicia': {'salary': {'year:2018-01': 100}}} +def test_canonicalize_period_keys(persons) -> None: + persons_json = {"Alicia": {"salary": {"year:2018-01": 100}}} simulation_builder = SimulationBuilder() simulation_builder.add_person_entity(persons, persons_json) population = Population(persons) simulation_builder.finalize_variables_init(population) - tools.assert_near(population.get_holder('salary').get_array('2018-12'), [100]) + tools.assert_near(population.get_holder("salary").get_array("2018-12"), [100]) -def test_finalize_households(tax_benefit_system): - simulation = Simulation(tax_benefit_system, tax_benefit_system.instantiate_entities()) +def test_finalize_households(tax_benefit_system) -> None: + simulation = Simulation( + tax_benefit_system, + tax_benefit_system.instantiate_entities(), + ) simulation_builder = SimulationBuilder() - simulation_builder.add_group_entity('persons', ['Alicia', 'Javier', 'Sarah', 'Tom'], simulation.household.entity, { - 'Household_1': {'parents': ['Alicia', 'Javier']}, - 'Household_2': {'parents': ['Tom'], 'children': ['Sarah']}, - }) + simulation_builder.add_group_entity( + "persons", + ["Alicia", "Javier", "Sarah", "Tom"], + simulation.household.entity, + { + "Household_1": {"parents": ["Alicia", "Javier"]}, + "Household_2": {"parents": ["Tom"], "children": ["Sarah"]}, + }, + ) simulation_builder.finalize_variables_init(simulation.household) tools.assert_near(simulation.household.members_entity_id, [0, 0, 1, 1]) - tools.assert_near(simulation.persons.has_role(entities.Household.PARENT), [True, True, False, True]) - - -def test_check_persons_to_allocate(): - entity_plural = 'familles' - persons_plural = 'individus' - person_id = 'Alicia' - entity_id = 'famille1' - role_id = 'parents' - persons_to_allocate = ['Alicia'] - persons_ids = ['Alicia'] + tools.assert_near( + simulation.persons.has_role(entities.Household.PARENT), + [True, True, False, True], + ) + + +def test_check_persons_to_allocate() -> None: + entity_plural = "familles" + persons_plural = "individus" + person_id = "Alicia" + entity_id = "famille1" + role_id = "parents" + persons_to_allocate = ["Alicia"] + persons_ids = ["Alicia"] index = 0 SimulationBuilder().check_persons_to_allocate( persons_plural, @@ -261,135 +384,196 @@ def test_check_persons_to_allocate(): role_id, persons_to_allocate, index, - ) + ) -def test_allocate_undeclared_person(): - entity_plural = 'familles' - persons_plural = 'individus' - person_id = 'Alicia' - entity_id = 'famille1' - role_id = 'parents' - persons_to_allocate = ['Alicia'] +def test_allocate_undeclared_person() -> None: + entity_plural = "familles" + persons_plural = "individus" + person_id = "Alicia" + entity_id = "famille1" + role_id = "parents" + persons_to_allocate = ["Alicia"] persons_ids = [] index = 0 with pytest.raises(SituationParsingError) as exception: SimulationBuilder().check_persons_to_allocate( - persons_plural, entity_plural, + persons_plural, + entity_plural, persons_ids, - person_id, entity_id, role_id, - persons_to_allocate, index) - assert exception.value.error == {'familles': {'famille1': {'parents': 'Unexpected value: Alicia. Alicia has been declared in famille1 parents, but has not been declared in individus.'}}} - - -def test_allocate_person_twice(): - entity_plural = 'familles' - persons_plural = 'individus' - person_id = 'Alicia' - entity_id = 'famille1' - role_id = 'parents' + person_id, + entity_id, + role_id, + persons_to_allocate, + index, + ) + assert exception.value.error == { + "familles": { + "famille1": { + "parents": "Unexpected value: Alicia. Alicia has been declared in famille1 parents, but has not been declared in individus.", + }, + }, + } + + +def test_allocate_person_twice() -> None: + entity_plural = "familles" + persons_plural = "individus" + person_id = "Alicia" + entity_id = "famille1" + role_id = "parents" persons_to_allocate = [] - persons_ids = ['Alicia'] + persons_ids = ["Alicia"] index = 0 with pytest.raises(SituationParsingError) as exception: SimulationBuilder().check_persons_to_allocate( - persons_plural, entity_plural, + persons_plural, + entity_plural, persons_ids, - person_id, entity_id, role_id, - persons_to_allocate, index) - assert exception.value.error == {'familles': {'famille1': {'parents': 'Alicia has been declared more than once in familles'}}} - - -def test_one_person_without_household(tax_benefit_system): - simulation_dict = {'persons': {'Alicia': {}}} - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, simulation_dict) + person_id, + entity_id, + role_id, + persons_to_allocate, + index, + ) + assert exception.value.error == { + "familles": { + "famille1": { + "parents": "Alicia has been declared more than once in familles", + }, + }, + } + + +def test_one_person_without_household(tax_benefit_system) -> None: + simulation_dict = {"persons": {"Alicia": {}}} + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + simulation_dict, + ) assert simulation.household.count == 1 - parents_in_households = simulation.household.nb_persons(role = entities.Household.PARENT) - assert parents_in_households.tolist() == [1] # household member default role is first_parent + parents_in_households = simulation.household.nb_persons( + role=entities.Household.PARENT, + ) + assert parents_in_households.tolist() == [ + 1, + ] # household member default role is first_parent -def test_some_person_without_household(tax_benefit_system): +def test_some_person_without_household(tax_benefit_system) -> None: input_yaml = """ persons: {'Alicia': {}, 'Bob': {}} household: {'parents': ['Alicia']} """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) assert simulation.household.count == 2 - parents_in_households = simulation.household.nb_persons(role = entities.Household.PARENT) - assert parents_in_households.tolist() == [1, 1] # household member default role is first_parent + parents_in_households = simulation.household.nb_persons( + role=entities.Household.PARENT, + ) + assert parents_in_households.tolist() == [ + 1, + 1, + ] # household member default role is first_parent -def test_nb_persons_in_households(tax_benefit_system): +def test_nb_persons_in_households(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] + households_ids: Iterable = ["c", "a", "b"] + persons_households: Iterable = ["c", "a", "a", "b", "a"] simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) - simulation_builder.join_with_persons(household_instance, persons_households, ['first_parent'] * 5) + simulation_builder.declare_person_entity("person", persons_ids) + household_instance = simulation_builder.declare_entity("household", households_ids) + simulation_builder.join_with_persons( + household_instance, + persons_households, + ["first_parent"] * 5, + ) - persons_in_households = simulation_builder.nb_persons('household') + persons_in_households = simulation_builder.nb_persons("household") assert persons_in_households.tolist() == [1, 3, 1] -def test_nb_persons_no_role(tax_benefit_system): +def test_nb_persons_no_role(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] + households_ids: Iterable = ["c", "a", "b"] + persons_households: Iterable = ["c", "a", "a", "b", "a"] simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) + simulation_builder.declare_person_entity("person", persons_ids) + household_instance = simulation_builder.declare_entity("household", households_ids) - simulation_builder.join_with_persons(household_instance, persons_households, ['first_parent'] * 5) - parents_in_households = household_instance.nb_persons(role = entities.Household.PARENT) + simulation_builder.join_with_persons( + household_instance, + persons_households, + ["first_parent"] * 5, + ) + parents_in_households = household_instance.nb_persons( + role=entities.Household.PARENT, + ) - assert parents_in_households.tolist() == [1, 3, 1] # household member default role is first_parent + assert parents_in_households.tolist() == [ + 1, + 3, + 1, + ] # household member default role is first_parent -def test_nb_persons_by_role(tax_benefit_system): +def test_nb_persons_by_role(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] - persons_households_roles: Iterable = ['child', 'first_parent', 'second_parent', 'first_parent', 'child'] + households_ids: Iterable = ["c", "a", "b"] + persons_households: Iterable = ["c", "a", "a", "b", "a"] + persons_households_roles: Iterable = [ + "child", + "first_parent", + "second_parent", + "first_parent", + "child", + ] simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) + simulation_builder.declare_person_entity("person", persons_ids) + household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( household_instance, persons_households, - persons_households_roles - ) - parents_in_households = household_instance.nb_persons(role = entities.Household.FIRST_PARENT) + persons_households_roles, + ) + parents_in_households = household_instance.nb_persons( + role=entities.Household.FIRST_PARENT, + ) assert parents_in_households.tolist() == [0, 1, 1] -def test_integral_roles(tax_benefit_system): +def test_integral_roles(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] + households_ids: Iterable = ["c", "a", "b"] + persons_households: Iterable = ["c", "a", "a", "b", "a"] # Same roles as test_nb_persons_by_role persons_households_roles: Iterable = [2, 0, 1, 0, 2] simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) + simulation_builder.declare_person_entity("person", persons_ids) + household_instance = simulation_builder.declare_entity("household", households_ids) simulation_builder.join_with_persons( household_instance, persons_households, - persons_households_roles - ) - parents_in_households = household_instance.nb_persons(role = entities.Household.FIRST_PARENT) + persons_households_roles, + ) + parents_in_households = household_instance.nb_persons( + role=entities.Household.FIRST_PARENT, + ) assert parents_in_households.tolist() == [0, 1, 1] @@ -397,66 +581,79 @@ def test_integral_roles(tax_benefit_system): # Test Intégration -def test_from_person_variable_to_group(tax_benefit_system): +def test_from_person_variable_to_group(tax_benefit_system) -> None: persons_ids: Iterable = [2, 0, 1, 4, 3] - households_ids: Iterable = ['c', 'a', 'b'] + households_ids: Iterable = ["c", "a", "b"] - persons_households: Iterable = ['c', 'a', 'a', 'b', 'a'] + persons_households: Iterable = ["c", "a", "a", "b", "a"] persons_salaries: Iterable = [6000, 2000, 1000, 1500, 1500] households_rents = [1036.6667, 781.6667, 271.6667] - period = '2018-12' + period = "2018-12" simulation_builder = SimulationBuilder() simulation_builder.create_entities(tax_benefit_system) - simulation_builder.declare_person_entity('person', persons_ids) + simulation_builder.declare_person_entity("person", persons_ids) - household_instance = simulation_builder.declare_entity('household', households_ids) - simulation_builder.join_with_persons(household_instance, persons_households, ['first_parent'] * 5) + household_instance = simulation_builder.declare_entity("household", households_ids) + simulation_builder.join_with_persons( + household_instance, + persons_households, + ["first_parent"] * 5, + ) simulation = simulation_builder.build(tax_benefit_system) - simulation.set_input('salary', period, persons_salaries) - simulation.set_input('rent', period, households_rents) + simulation.set_input("salary", period, persons_salaries) + simulation.set_input("rent", period, households_rents) - total_taxes = simulation.calculate('total_taxes', period) + total_taxes = simulation.calculate("total_taxes", period) assert total_taxes == pytest.approx(households_rents) - assert total_taxes / simulation.calculate('rent', period) == pytest.approx(1) + assert total_taxes / simulation.calculate("rent", period) == pytest.approx(1) -def test_simulation(tax_benefit_system): +def test_simulation(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: 12000 """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) assert simulation.get_array("salary", "2016-10") == 12000 simulation.calculate("income_tax", "2016-10") simulation.calculate("total_taxes", "2016-10") -def test_vectorial_input(tax_benefit_system): +def test_vectorial_input(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: [12000, 20000] """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) tools.assert_near(simulation.get_array("salary", "2016-10"), [12000, 20000]) simulation.calculate("income_tax", "2016-10") simulation.calculate("total_taxes", "2016-10") -def test_fully_specified_entities(tax_benefit_system): - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, situation_examples.couple) +def test_fully_specified_entities(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + situation_examples.couple, + ) assert simulation.household.count == 1 assert simulation.persons.count == 2 -def test_single_entity_shortcut(tax_benefit_system): +def test_single_entity_shortcut(tax_benefit_system) -> None: input_yaml = """ persons: Alicia: {} @@ -465,11 +662,14 @@ def test_single_entity_shortcut(tax_benefit_system): parents: [Alicia, Javier] """ - simulation = SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + simulation = SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) assert simulation.household.count == 1 -def test_order_preserved(tax_benefit_system): +def test_order_preserved(tax_benefit_system) -> None: input_yaml = """ persons: Javier: {} @@ -484,10 +684,10 @@ def test_order_preserved(tax_benefit_system): data = test_runner.yaml.safe_load(input_yaml) simulation = SimulationBuilder().build_from_dict(tax_benefit_system, data) - assert simulation.persons.ids == ['Javier', 'Alicia', 'Sarah', 'Tom'] + assert simulation.persons.ids == ["Javier", "Alicia", "Sarah", "Tom"] -def test_inconsistent_input(tax_benefit_system): +def test_inconsistent_input(tax_benefit_system) -> None: input_yaml = """ salary: 2016-10: [12000, 20000] @@ -495,5 +695,8 @@ def test_inconsistent_input(tax_benefit_system): 2016-10: [100, 200, 300] """ with pytest.raises(ValueError) as error: - SimulationBuilder().build_from_dict(tax_benefit_system, test_runner.yaml.safe_load(input_yaml)) + SimulationBuilder().build_from_dict( + tax_benefit_system, + test_runner.yaml.safe_load(input_yaml), + ) assert "its length is 3 while there are 2" in error.value.args[0] diff --git a/tests/core/test_simulations.py b/tests/core/test_simulations.py index d5f3ac8008..7f4897e776 100644 --- a/tests/core/test_simulations.py +++ b/tests/core/test_simulations.py @@ -1,46 +1,47 @@ +import pytest + from openfisca_country_template.situation_examples import single +from openfisca_core import errors, periods from openfisca_core.simulations import SimulationBuilder -def test_calculate_full_tracer(tax_benefit_system): +def test_calculate_full_tracer(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) simulation.trace = True - simulation.calculate('income_tax', '2017-01') + simulation.calculate("income_tax", "2017-01") income_tax_node = simulation.tracer.trees[0] - assert income_tax_node.name == 'income_tax' - assert str(income_tax_node.period) == '2017-01' + assert income_tax_node.name == "income_tax" + assert str(income_tax_node.period) == "2017-01" assert income_tax_node.value == 0 salary_node = income_tax_node.children[0] - assert salary_node.name == 'salary' - assert str(salary_node.period) == '2017-01' + assert salary_node.name == "salary" + assert str(salary_node.period) == "2017-01" assert salary_node.parameters == [] assert len(income_tax_node.parameters) == 1 - assert income_tax_node.parameters[0].name == 'taxes.income_tax_rate' - assert income_tax_node.parameters[0].period == '2017-01-01' + assert income_tax_node.parameters[0].name == "taxes.income_tax_rate" + assert income_tax_node.parameters[0].period == "2017-01-01" assert income_tax_node.parameters[0].value == 0.15 -def test_get_entity_not_found(tax_benefit_system): +def test_get_entity_not_found(tax_benefit_system) -> None: simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) - assert simulation.get_entity(plural = "no_such_entities") is None - - -def test_clone(tax_benefit_system): - simulation = SimulationBuilder().build_from_entities(tax_benefit_system, - { - "persons": { - "bill": {"salary": {"2017-01": 3000}}, - }, - "households": { - "household": { - "parents": ["bill"] - } - } - }) + assert simulation.get_entity(plural="no_such_entities") is None + + +def test_clone(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_from_entities( + tax_benefit_system, + { + "persons": { + "bill": {"salary": {"2017-01": 3000}}, + }, + "households": {"household": {"parents": ["bill"]}}, + }, + ) simulation_clone = simulation.clone() assert simulation != simulation_clone @@ -50,17 +51,31 @@ def test_clone(tax_benefit_system): assert simulation.persons != simulation_clone.persons - salary_holder = simulation.person.get_holder('salary') - salary_holder_clone = simulation_clone.person.get_holder('salary') + salary_holder = simulation.person.get_holder("salary") + salary_holder_clone = simulation_clone.person.get_holder("salary") assert salary_holder != salary_holder_clone assert salary_holder_clone.simulation == simulation_clone assert salary_holder_clone.population == simulation_clone.persons -def test_get_memory_usage(tax_benefit_system): +def test_get_memory_usage(tax_benefit_system) -> None: simulation = SimulationBuilder().build_from_entities(tax_benefit_system, single) - simulation.calculate('disposable_income', '2017-01') - memory_usage = simulation.get_memory_usage(variables = ['salary']) - assert(memory_usage['total_nb_bytes'] > 0) - assert(len(memory_usage['by_variable']) == 1) + simulation.calculate("disposable_income", "2017-01") + memory_usage = simulation.get_memory_usage(variables=["salary"]) + assert memory_usage["total_nb_bytes"] > 0 + assert len(memory_usage["by_variable"]) == 1 + + +def test_invalidate_cache_when_spiral_error_detected(tax_benefit_system) -> None: + simulation = SimulationBuilder().build_default_simulation(tax_benefit_system) + tracer = simulation.tracer + + tracer.record_calculation_start("a", periods.period(2017)) + tracer.record_calculation_start("b", periods.period(2016)) + tracer.record_calculation_start("a", periods.period(2016)) + + with pytest.raises(errors.SpiralError): + simulation._check_for_cycle("a", periods.period(2016)) + + assert len(simulation.invalidated_caches) == 3 diff --git a/tests/core/test_tracers.py b/tests/core/test_tracers.py index ef08992b3c..178b957ec4 100644 --- a/tests/core/test_tracers.py +++ b/tests/core/test_tracers.py @@ -1,43 +1,51 @@ -# -*- coding: utf-8 -*- - +import csv import json import os -import csv -import numpy as np -from pytest import fixture, mark, raises, approx -from openfisca_core.simulations import Simulation, CycleError, SpiralError -from openfisca_core.tracers import SimpleTracer, FullTracer, TracingParameterNodeAtInstant, TraceNode +import numpy +from pytest import approx, fixture, mark, raises + from openfisca_country_template.variables.housing import HousingOccupancyStatus + +from openfisca_core import periods +from openfisca_core.simulations import CycleError, Simulation, SpiralError +from openfisca_core.tracers import ( + FullTracer, + SimpleTracer, + TraceNode, + TracingParameterNodeAtInstant, +) + from .parameters_fancy_indexing.test_fancy_indexing import parameters -class StubSimulation(Simulation): +class TestException(Exception): ... + - def __init__(self): +class StubSimulation(Simulation): + def __init__(self) -> None: self.exception = None self.max_spiral_loops = 1 - def _calculate(self, variable, period): + def _calculate(self, variable, period) -> None: if self.exception: raise self.exception - def invalidate_cache_entry(self, variable, period): + def invalidate_cache_entry(self, variable, period) -> None: pass - def purge_cache_of_invalid_values(self): + def purge_cache_of_invalid_values(self) -> None: pass class MockTracer: - - def record_calculation_start(self, variable, period): + def record_calculation_start(self, variable, period) -> None: self.calculation_start_recorded = True - def record_calculation_result(self, value): + def record_calculation_result(self, value) -> None: self.recorded_result = True - def record_calculation_end(self): + def record_calculation_end(self) -> None: self.calculation_end_recorded = True @@ -47,113 +55,134 @@ def tracer(): @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_stack_one_level(tracer): - tracer.record_calculation_start('a', 2017) +def test_stack_one_level(tracer) -> None: + tracer.record_calculation_start("a", 2017) + assert len(tracer.stack) == 1 - assert tracer.stack == [{'name': 'a', 'period': 2017}] + assert tracer.stack == [{"name": "a", "period": 2017}] tracer.record_calculation_end() + assert tracer.stack == [] @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_stack_two_levels(tracer): - tracer.record_calculation_start('a', 2017) - tracer.record_calculation_start('b', 2017) +def test_stack_two_levels(tracer) -> None: + tracer.record_calculation_start("a", 2017) + tracer.record_calculation_start("b", 2017) + assert len(tracer.stack) == 2 - assert tracer.stack == [{'name': 'a', 'period': 2017}, {'name': 'b', 'period': 2017}] + assert tracer.stack == [ + {"name": "a", "period": 2017}, + {"name": "b", "period": 2017}, + ] tracer.record_calculation_end() + assert len(tracer.stack) == 1 - assert tracer.stack == [{'name': 'a', 'period': 2017}] + assert tracer.stack == [{"name": "a", "period": 2017}] @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_tracer_contract(tracer): +def test_tracer_contract(tracer) -> None: simulation = StubSimulation() simulation.tracer = MockTracer() - simulation.calculate('a', 2017) + simulation.calculate("a", 2017) assert simulation.tracer.calculation_start_recorded assert simulation.tracer.calculation_end_recorded -def test_exception_robustness(): +def test_exception_robustness() -> None: simulation = StubSimulation() simulation.tracer = MockTracer() - simulation.exception = Exception(":-o") + simulation.exception = TestException(":-o") - with raises(Exception): - simulation.calculate('a', 2017) + with raises(TestException): + simulation.calculate("a", 2017) assert simulation.tracer.calculation_start_recorded assert simulation.tracer.calculation_end_recorded @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_cycle_error(tracer): +def test_cycle_error(tracer) -> None: simulation = StubSimulation() simulation.tracer = tracer - tracer.record_calculation_start('a', 2017) - simulation._check_for_cycle('a', 2017) - tracer.record_calculation_start('a', 2017) + tracer.record_calculation_start("a", 2017) + + assert not simulation._check_for_cycle("a", 2017) + + tracer.record_calculation_start("a", 2017) + with raises(CycleError): - simulation._check_for_cycle('a', 2017) + simulation._check_for_cycle("a", 2017) + + assert len(tracer.stack) == 2 + assert tracer.stack == [ + {"name": "a", "period": 2017}, + {"name": "a", "period": 2017}, + ] @mark.parametrize("tracer", [SimpleTracer(), FullTracer()]) -def test_spiral_error(tracer): +def test_spiral_error(tracer) -> None: simulation = StubSimulation() simulation.tracer = tracer - tracer.record_calculation_start('a', 2017) - tracer.record_calculation_start('a', 2016) - tracer.record_calculation_start('a', 2015) + + tracer.record_calculation_start("a", periods.period(2017)) + tracer.record_calculation_start("b", periods.period(2016)) + tracer.record_calculation_start("a", periods.period(2016)) with raises(SpiralError): - simulation._check_for_cycle('a', 2015) + simulation._check_for_cycle("a", periods.period(2016)) + + assert len(tracer.stack) == 3 + assert tracer.stack == [ + {"name": "a", "period": periods.period(2017)}, + {"name": "b", "period": periods.period(2016)}, + {"name": "a", "period": periods.period(2016)}, + ] -def test_full_tracer_one_calculation(tracer): - tracer._enter_calculation('a', 2017) +def test_full_tracer_one_calculation(tracer) -> None: + tracer._enter_calculation("a", 2017) tracer._exit_calculation() + assert tracer.stack == [] assert len(tracer.trees) == 1 - assert tracer.trees[0].name == 'a' + assert tracer.trees[0].name == "a" assert tracer.trees[0].period == 2017 assert tracer.trees[0].children == [] -def test_full_tracer_2_branches(tracer): - tracer._enter_calculation('a', 2017) - - tracer._enter_calculation('b', 2017) +def test_full_tracer_2_branches(tracer) -> None: + tracer._enter_calculation("a", 2017) + tracer._enter_calculation("b", 2017) tracer._exit_calculation() - - tracer._enter_calculation('c', 2017) + tracer._enter_calculation("c", 2017) tracer._exit_calculation() - tracer._exit_calculation() assert len(tracer.trees) == 1 assert len(tracer.trees[0].children) == 2 -def test_full_tracer_2_trees(tracer): - tracer._enter_calculation('b', 2017) +def test_full_tracer_2_trees(tracer) -> None: + tracer._enter_calculation("b", 2017) tracer._exit_calculation() - - tracer._enter_calculation('c', 2017) + tracer._enter_calculation("c", 2017) tracer._exit_calculation() assert len(tracer.trees) == 2 -def test_full_tracer_3_generations(tracer): - tracer._enter_calculation('a', 2017) - tracer._enter_calculation('b', 2017) - tracer._enter_calculation('c', 2017) +def test_full_tracer_3_generations(tracer) -> None: + tracer._enter_calculation("a", 2017) + tracer._enter_calculation("b", 2017) + tracer._enter_calculation("c", 2017) tracer._exit_calculation() tracer._exit_calculation() tracer._exit_calculation() @@ -163,117 +192,118 @@ def test_full_tracer_3_generations(tracer): assert len(tracer.trees[0].children[0].children) == 1 -def test_full_tracer_variable_nb_requests(tracer): - tracer._enter_calculation('a', '2017-01') - tracer._enter_calculation('a', '2017-02') +def test_full_tracer_variable_nb_requests(tracer) -> None: + tracer._enter_calculation("a", "2017-01") + tracer._enter_calculation("a", "2017-02") - assert tracer.get_nb_requests('a') == 2 + assert tracer.get_nb_requests("a") == 2 -def test_simulation_calls_record_calculation_result(): +def test_simulation_calls_record_calculation_result() -> None: simulation = StubSimulation() simulation.tracer = MockTracer() - simulation.calculate('a', 2017) + simulation.calculate("a", 2017) assert simulation.tracer.recorded_result -def test_record_calculation_result(tracer): - tracer._enter_calculation('a', 2017) - tracer.record_calculation_result(np.asarray(100)) +def test_record_calculation_result(tracer) -> None: + tracer._enter_calculation("a", 2017) + tracer.record_calculation_result(numpy.asarray(100)) tracer._exit_calculation() assert tracer.trees[0].value == 100 -def test_flat_trace(tracer): - tracer._enter_calculation('a', 2019) - tracer._enter_calculation('b', 2019) +def test_flat_trace(tracer) -> None: + tracer._enter_calculation("a", 2019) + tracer._enter_calculation("b", 2019) tracer._exit_calculation() tracer._exit_calculation() trace = tracer.get_flat_trace() assert len(trace) == 2 - assert trace['a<2019>']['dependencies'] == ['b<2019>'] - assert trace['b<2019>']['dependencies'] == [] + assert trace["a<2019>"]["dependencies"] == ["b<2019>"] + assert trace["b<2019>"]["dependencies"] == [] -def test_flat_trace_serialize_vectorial_values(tracer): - tracer._enter_calculation('a', 2019) - tracer.record_parameter_access('x.y.z', 2019, np.asarray([100, 200, 300])) - tracer.record_calculation_result(np.asarray([10, 20, 30])) +def test_flat_trace_serialize_vectorial_values(tracer) -> None: + tracer._enter_calculation("a", 2019) + tracer.record_parameter_access("x.y.z", 2019, numpy.asarray([100, 200, 300])) + tracer.record_calculation_result(numpy.asarray([10, 20, 30])) tracer._exit_calculation() trace = tracer.get_serialized_flat_trace() - assert json.dumps(trace['a<2019>']['value']) - assert json.dumps(trace['a<2019>']['parameters']['x.y.z<2019>']) + assert json.dumps(trace["a<2019>"]["value"]) + assert json.dumps(trace["a<2019>"]["parameters"]["x.y.z<2019>"]) -def test_flat_trace_with_parameter(tracer): - tracer._enter_calculation('a', 2019) - tracer.record_parameter_access('p', '2019-01-01', 100) +def test_flat_trace_with_parameter(tracer) -> None: + tracer._enter_calculation("a", 2019) + tracer.record_parameter_access("p", "2019-01-01", 100) tracer._exit_calculation() trace = tracer.get_flat_trace() assert len(trace) == 1 - assert trace['a<2019>']['parameters'] == {'p<2019-01-01>': 100} + assert trace["a<2019>"]["parameters"] == {"p<2019-01-01>": 100} -def test_flat_trace_with_cache(tracer): - tracer._enter_calculation('a', 2019) - tracer._enter_calculation('b', 2019) - tracer._enter_calculation('c', 2019) +def test_flat_trace_with_cache(tracer) -> None: + tracer._enter_calculation("a", 2019) + tracer._enter_calculation("b", 2019) + tracer._enter_calculation("c", 2019) tracer._exit_calculation() tracer._exit_calculation() tracer._exit_calculation() - tracer._enter_calculation('b', 2019) + tracer._enter_calculation("b", 2019) tracer._exit_calculation() trace = tracer.get_flat_trace() - assert trace['b<2019>']['dependencies'] == ['c<2019>'] + assert trace["b<2019>"]["dependencies"] == ["c<2019>"] -def test_calculation_time(): +def test_calculation_time() -> None: tracer = FullTracer() - tracer._enter_calculation('a', 2019) + tracer._enter_calculation("a", 2019) tracer._record_start_time(1500) tracer._record_end_time(2500) tracer._exit_calculation() - performance_json = tracer.performance_log._json() - assert performance_json['name'] == 'All calculations' - assert performance_json['value'] == 1000 - simulation_children = performance_json['children'] - assert simulation_children[0]['name'] == 'a<2019>' - assert simulation_children[0]['value'] == 1000 + assert performance_json["name"] == "All calculations" + assert performance_json["value"] == 1000 + + simulation_children = performance_json["children"] + + assert simulation_children[0]["name"] == "a<2019>" + assert simulation_children[0]["value"] == 1000 @fixture def tracer_calc_time(): tracer = FullTracer() - tracer._enter_calculation('a', 2019) + tracer._enter_calculation("a", 2019) tracer._record_start_time(1500) - tracer._enter_calculation('b', 2019) + tracer._enter_calculation("b", 2019) tracer._record_start_time(1600) tracer._record_end_time(2300) tracer._exit_calculation() - tracer._enter_calculation('c', 2019) + tracer._enter_calculation("c", 2019) tracer._record_start_time(2300) tracer._record_end_time(2400) tracer._exit_calculation() # Cache call - tracer._enter_calculation('c', 2019) + tracer._enter_calculation("c", 2019) tracer._record_start_time(2400) tracer._record_end_time(2410) tracer._exit_calculation() @@ -281,7 +311,7 @@ def tracer_calc_time(): tracer._record_end_time(2500) tracer._exit_calculation() - tracer._enter_calculation('a', 2018) + tracer._enter_calculation("a", 2018) tracer._record_start_time(1800) tracer._record_end_time(1800 + 200) tracer._exit_calculation() @@ -289,185 +319,243 @@ def tracer_calc_time(): return tracer -def test_calculation_time_with_depth(tracer_calc_time): +def test_calculation_time_with_depth(tracer_calc_time) -> None: tracer = tracer_calc_time performance_json = tracer.performance_log._json() - simulation_grand_children = performance_json['children'][0]['children'] + simulation_grand_children = performance_json["children"][0]["children"] - assert simulation_grand_children[0]['name'] == 'b<2019>' - assert simulation_grand_children[0]['value'] == 700 + assert simulation_grand_children[0]["name"] == "b<2019>" + assert simulation_grand_children[0]["value"] == 700 -def test_flat_trace_calc_time(tracer_calc_time): +def test_flat_trace_calc_time(tracer_calc_time) -> None: tracer = tracer_calc_time flat_trace = tracer.get_flat_trace() - assert flat_trace['a<2019>']['calculation_time'] == 1000 - assert flat_trace['b<2019>']['calculation_time'] == 700 - assert flat_trace['c<2019>']['calculation_time'] == 100 - assert flat_trace['a<2019>']['formula_time'] == 190 # 1000 - 700 - 100 - 10 - assert flat_trace['b<2019>']['formula_time'] == 700 - assert flat_trace['c<2019>']['formula_time'] == 100 + assert flat_trace["a<2019>"]["calculation_time"] == 1000 + assert flat_trace["b<2019>"]["calculation_time"] == 700 + assert flat_trace["c<2019>"]["calculation_time"] == 100 + assert flat_trace["a<2019>"]["formula_time"] == 190 # 1000 - 700 - 100 - 10 + assert flat_trace["b<2019>"]["formula_time"] == 700 + assert flat_trace["c<2019>"]["formula_time"] == 100 -def test_generate_performance_table(tracer_calc_time, tmpdir): +def test_generate_performance_table(tracer_calc_time, tmpdir) -> None: tracer = tracer_calc_time tracer.generate_performance_tables(tmpdir) - with open(os.path.join(tmpdir, 'performance_table.csv'), 'r') as csv_file: + + with open(os.path.join(tmpdir, "performance_table.csv")) as csv_file: csv_reader = csv.DictReader(csv_file) csv_rows = list(csv_reader) + assert len(csv_rows) == 4 - a_row = next(row for row in csv_rows if row['name'] == 'a<2019>') - assert float(a_row['calculation_time']) == 1000 - assert float(a_row['formula_time']) == 190 - with open(os.path.join(tmpdir, 'aggregated_performance_table.csv'), 'r') as csv_file: + a_row = next(row for row in csv_rows if row["name"] == "a<2019>") + + assert float(a_row["calculation_time"]) == 1000 + assert float(a_row["formula_time"]) == 190 + + with open(os.path.join(tmpdir, "aggregated_performance_table.csv")) as csv_file: aggregated_csv_reader = csv.DictReader(csv_file) aggregated_csv_rows = list(aggregated_csv_reader) + assert len(aggregated_csv_rows) == 3 - a_row = next(row for row in aggregated_csv_rows if row['name'] == 'a') - assert float(a_row['calculation_time']) == 1000 + 200 - assert float(a_row['formula_time']) == 190 + 200 + a_row = next(row for row in aggregated_csv_rows if row["name"] == "a") -def test_get_aggregated_calculation_times(tracer_calc_time): - perf_log = tracer_calc_time.performance_log - aggregated_calculation_times = perf_log.aggregate_calculation_times(tracer_calc_time.get_flat_trace()) + assert float(a_row["calculation_time"]) == 1000 + 200 + assert float(a_row["formula_time"]) == 190 + 200 - assert aggregated_calculation_times['a']['calculation_time'] == 1000 + 200 - assert aggregated_calculation_times['a']['formula_time'] == 190 + 200 - assert aggregated_calculation_times['a']['avg_calculation_time'] == (1000 + 200) / 2 - assert aggregated_calculation_times['a']['avg_formula_time'] == (190 + 200) / 2 +def test_get_aggregated_calculation_times(tracer_calc_time) -> None: + perf_log = tracer_calc_time.performance_log + aggregated_calculation_times = perf_log.aggregate_calculation_times( + tracer_calc_time.get_flat_trace(), + ) + + assert aggregated_calculation_times["a"]["calculation_time"] == 1000 + 200 + assert aggregated_calculation_times["a"]["formula_time"] == 190 + 200 + assert aggregated_calculation_times["a"]["avg_calculation_time"] == (1000 + 200) / 2 + assert aggregated_calculation_times["a"]["avg_formula_time"] == (190 + 200) / 2 -def test_rounding(): - node_a = TraceNode('a', 2017) +def test_rounding() -> None: + node_a = TraceNode("a", 2017) node_a.start = 1.23456789 node_a.end = node_a.start + 1.23456789e-03 assert node_a.calculation_time() == 1.235e-03 # Keep only 3 significant figures - node_b = TraceNode('b', 2017) + node_b = TraceNode("b", 2017) node_b.start = node_a.start node_b.end = node_a.end - 1.23456789e-08 node_a.children = [node_b] - assert node_a.formula_time() == 1.235e-08 # The rounding should not prevent from calculating a precise formula_time + assert ( + node_a.formula_time() == 1.235e-08 + ) # The rounding should not prevent from calculating a precise formula_time -def test_variable_stats(tracer): +def test_variable_stats(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) tracer._enter_calculation("B", 2017) tracer._enter_calculation("B", 2016) - assert tracer.get_nb_requests('B') == 3 - assert tracer.get_nb_requests('A') == 1 - assert tracer.get_nb_requests('C') == 0 + assert tracer.get_nb_requests("B") == 3 + assert tracer.get_nb_requests("A") == 1 + assert tracer.get_nb_requests("C") == 0 -def test_log_format(tracer): +def test_log_format(tracer) -> None: tracer._enter_calculation("A", 2017) tracer._enter_calculation("B", 2017) - tracer.record_calculation_result(np.asarray([1])) + tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() - tracer.record_calculation_result(np.asarray([2])) + tracer.record_calculation_result(numpy.asarray([2])) tracer._exit_calculation() - lines = tracer.computation_log.lines() - assert lines[0] == ' A<2017> >> [2]' - assert lines[1] == ' B<2017> >> [1]' + + assert lines[0] == " A<2017> >> [2]" + assert lines[1] == " B<2017> >> [1]" -def test_log_format_forest(tracer): +def test_log_format_forest(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(np.asarray([1])) + tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() - tracer._enter_calculation("B", 2017) - tracer.record_calculation_result(np.asarray([2])) + tracer.record_calculation_result(numpy.asarray([2])) tracer._exit_calculation() - lines = tracer.computation_log.lines() - assert lines[0] == ' A<2017> >> [1]' - assert lines[1] == ' B<2017> >> [2]' + assert lines[0] == " A<2017> >> [1]" + assert lines[1] == " B<2017> >> [2]" -def test_log_aggregate(tracer): + +def test_log_aggregate(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(np.asarray([1])) + tracer.record_calculation_result(numpy.asarray([1])) tracer._exit_calculation() + lines = tracer.computation_log.lines(aggregate=True) - lines = tracer.computation_log.lines(aggregate = True) assert lines[0] == " A<2017> >> {'avg': 1.0, 'max': 1, 'min': 1}" -def test_log_aggregate_with_enum(tracer): +def test_log_aggregate_with_enum(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(HousingOccupancyStatus.encode(np.repeat('tenant', 100))) + tracer.record_calculation_result( + HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)), + ) tracer._exit_calculation() + lines = tracer.computation_log.lines(aggregate=True) - lines = tracer.computation_log.lines(aggregate = True) - assert lines[0] == " A<2017> >> {'avg': , 'max': , 'min': }" + assert ( + lines[0] + == " A<2017> >> {'avg': EnumArray(HousingOccupancyStatus.tenant), 'max': EnumArray(HousingOccupancyStatus.tenant), 'min': EnumArray(HousingOccupancyStatus.tenant)}" + ) -def test_log_aggregate_with_strings(tracer): +def test_log_aggregate_with_strings(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(np.repeat('foo', 100)) + tracer.record_calculation_result(numpy.repeat("foo", 100)) tracer._exit_calculation() + lines = tracer.computation_log.lines(aggregate=True) - lines = tracer.computation_log.lines(aggregate = True) assert lines[0] == " A<2017> >> {'avg': '?', 'max': '?', 'min': '?'}" -def test_no_wrapping(tracer): +def test_log_max_depth(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(HousingOccupancyStatus.encode(np.repeat('tenant', 100))) + tracer._enter_calculation("B", 2017) + tracer._enter_calculation("C", 2017) + tracer.record_calculation_result(numpy.asarray([3])) tracer._exit_calculation() + tracer.record_calculation_result(numpy.asarray([2])) + tracer._exit_calculation() + tracer.record_calculation_result(numpy.asarray([1])) + tracer._exit_calculation() + + assert len(tracer.computation_log.lines()) == 3 + assert len(tracer.computation_log.lines(max_depth=4)) == 3 + assert len(tracer.computation_log.lines(max_depth=3)) == 3 + assert len(tracer.computation_log.lines(max_depth=2)) == 2 + assert len(tracer.computation_log.lines(max_depth=1)) == 1 + assert len(tracer.computation_log.lines(max_depth=0)) == 0 + +def test_no_wrapping(tracer) -> None: + tracer._enter_calculation("A", 2017) + tracer.record_calculation_result( + HousingOccupancyStatus.encode(numpy.repeat("tenant", 100)), + ) + tracer._exit_calculation() lines = tracer.computation_log.lines() + assert "'tenant'" in lines[0] assert "\n" not in lines[0] -def test_trace_enums(tracer): +def test_trace_enums(tracer) -> None: tracer._enter_calculation("A", 2017) - tracer.record_calculation_result(HousingOccupancyStatus.encode(np.array(['tenant']))) + tracer.record_calculation_result( + HousingOccupancyStatus.encode(numpy.array(["tenant"])), + ) tracer._exit_calculation() - lines = tracer.computation_log.lines() + assert lines[0] == " A<2017> >> ['tenant']" # Tests on tracing with fancy indexing -zone = np.asarray(['z1', 'z2', 'z2', 'z1']) -housing_occupancy_status = np.asarray(['owner', 'owner', 'tenant', 'tenant']) -family_status = np.asarray(['single', 'couple', 'single', 'couple']) +zone = numpy.asarray(["z1", "z2", "z2", "z1"]) +housing_occupancy_status = numpy.asarray(["owner", "owner", "tenant", "tenant"]) +family_status = numpy.asarray(["single", "couple", "single", "couple"]) -def check_tracing_params(accessor, param_key): +def check_tracing_params(accessor, param_key) -> None: tracer = FullTracer() - tracer._enter_calculation('A', '2015-01') - tracingParams = TracingParameterNodeAtInstant(parameters('2015-01-01'), tracer) + + tracer._enter_calculation("A", "2015-01") + + tracingParams = TracingParameterNodeAtInstant(parameters("2015-01-01"), tracer) param = accessor(tracingParams) + assert tracer.trees[0].parameters[0].name == param_key assert tracer.trees[0].parameters[0].value == approx(param) -@mark.parametrize("test", [ - (lambda P: P.rate.single.owner.z1, 'rate.single.owner.z1'), # basic case - (lambda P: P.rate.single.owner[zone], 'rate.single.owner'), # fancy indexing on leaf - (lambda P: P.rate.single[housing_occupancy_status].z1, 'rate.single'), # on a node - (lambda P: P.rate.single[housing_occupancy_status][zone], 'rate.single'), # double fancy indexing - (lambda P: P.rate[family_status][housing_occupancy_status].z2, 'rate'), # double + node - (lambda P: P.rate[family_status][housing_occupancy_status][zone], 'rate'), # triple - ]) -def test_parameters(test): +@mark.parametrize( + "test", + [ + (lambda P: P.rate.single.owner.z1, "rate.single.owner.z1"), # basic case + ( + lambda P: P.rate.single.owner[zone], + "rate.single.owner", + ), # fancy indexing on leaf + ( + lambda P: P.rate.single[housing_occupancy_status].z1, + "rate.single", + ), # on a node + ( + lambda P: P.rate.single[housing_occupancy_status][zone], + "rate.single", + ), # double fancy indexing + ( + lambda P: P.rate[family_status][housing_occupancy_status].z2, + "rate", + ), # double + node + ( + lambda P: P.rate[family_status][housing_occupancy_status][zone], + "rate", + ), # triple + ], +) +def test_parameters(test) -> None: check_tracing_params(*test) -def test_browse_trace(): +def test_browse_trace() -> None: tracer = FullTracer() tracer._enter_calculation("B", 2017) @@ -480,6 +568,6 @@ def test_browse_trace(): tracer._enter_calculation("F", 2017) tracer._exit_calculation() tracer._exit_calculation() - browsed_nodes = [node.name for node in tracer.browse_trace()] - assert browsed_nodes == ['B', 'C', 'D', 'E', 'F'] + + assert browsed_nodes == ["B", "C", "D", "E", "F"] diff --git a/tests/core/test_yaml.py b/tests/core/test_yaml.py index f63e37ff39..1672ea3453 100644 --- a/tests/core/test_yaml.py +++ b/tests/core/test_yaml.py @@ -2,10 +2,10 @@ import subprocess import pytest + import openfisca_extension_template from openfisca_core.tools.test_runner import run_tests - from tests.fixtures import yaml_tests yaml_tests_dir = os.path.dirname(yaml_tests.__file__) @@ -13,93 +13,121 @@ EXIT_TESTSFAILED = 1 -def run_yaml_test(tax_benefit_system, path, options = None): +def run_yaml_test(tax_benefit_system, path, options=None): yaml_path = os.path.join(yaml_tests_dir, path) if options is None: options = {} - result = run_tests(tax_benefit_system, yaml_path, options) - return result + return run_tests(tax_benefit_system, yaml_path, options) -def test_success(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_success.yml') == EXIT_OK +def test_success(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_success.yml") == EXIT_OK -def test_fail(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_failure.yaml') == EXIT_TESTSFAILED +def test_fail(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_failure.yaml") == EXIT_TESTSFAILED -def test_relative_error_margin_success(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_relative_error_margin.yaml') == EXIT_OK +def test_relative_error_margin_success(tax_benefit_system) -> None: + assert ( + run_yaml_test(tax_benefit_system, "test_relative_error_margin.yaml") == EXIT_OK + ) -def test_relative_error_margin_fail(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'failing_test_relative_error_margin.yaml') == EXIT_TESTSFAILED +def test_relative_error_margin_fail(tax_benefit_system) -> None: + assert ( + run_yaml_test(tax_benefit_system, "failing_test_relative_error_margin.yaml") + == EXIT_TESTSFAILED + ) -def test_absolute_error_margin_success(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_absolute_error_margin.yaml') == EXIT_OK +def test_absolute_error_margin_success(tax_benefit_system) -> None: + assert ( + run_yaml_test(tax_benefit_system, "test_absolute_error_margin.yaml") == EXIT_OK + ) -def test_absolute_error_margin_fail(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'failing_test_absolute_error_margin.yaml') == EXIT_TESTSFAILED +def test_absolute_error_margin_fail(tax_benefit_system) -> None: + assert ( + run_yaml_test(tax_benefit_system, "failing_test_absolute_error_margin.yaml") + == EXIT_TESTSFAILED + ) -def test_run_tests_from_directory(tax_benefit_system): - dir_path = os.path.join(yaml_tests_dir, 'directory') +def test_run_tests_from_directory(tax_benefit_system) -> None: + dir_path = os.path.join(yaml_tests_dir, "directory") assert run_yaml_test(tax_benefit_system, dir_path) == EXIT_OK -def test_with_reform(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_with_reform.yaml') == EXIT_OK +def test_with_reform(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_with_reform.yaml") == EXIT_OK -def test_with_extension(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_with_extension.yaml') == EXIT_OK +def test_with_extension(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_with_extension.yaml") == EXIT_OK -def test_with_anchors(tax_benefit_system): - assert run_yaml_test(tax_benefit_system, 'test_with_anchors.yaml') == EXIT_OK +def test_with_anchors(tax_benefit_system) -> None: + assert run_yaml_test(tax_benefit_system, "test_with_anchors.yaml") == EXIT_OK -def test_run_tests_from_directory_fail(tax_benefit_system): +def test_run_tests_from_directory_fail(tax_benefit_system) -> None: assert run_yaml_test(tax_benefit_system, yaml_tests_dir) == EXIT_TESTSFAILED -def test_name_filter(tax_benefit_system): - assert run_yaml_test( - tax_benefit_system, - yaml_tests_dir, - options = {'name_filter': 'success'} - ) == EXIT_OK +def test_name_filter(tax_benefit_system) -> None: + assert ( + run_yaml_test( + tax_benefit_system, + yaml_tests_dir, + options={"name_filter": "success"}, + ) + == EXIT_OK + ) -def test_shell_script(): - yaml_path = os.path.join(yaml_tests_dir, 'test_success.yml') - command = ['openfisca', 'test', yaml_path, '-c', 'openfisca_country_template'] - with open(os.devnull, 'wb') as devnull: - subprocess.check_call(command, stdout = devnull, stderr = devnull) +def test_shell_script() -> None: + yaml_path = os.path.join(yaml_tests_dir, "test_success.yml") + command = ["openfisca", "test", yaml_path, "-c", "openfisca_country_template"] + with open(os.devnull, "wb") as devnull: + subprocess.check_call(command, stdout=devnull, stderr=devnull) -def test_failing_shell_script(): - yaml_path = os.path.join(yaml_tests_dir, 'test_failure.yaml') - command = ['openfisca', 'test', yaml_path, '-c', 'openfisca_dummy_country'] - with open(os.devnull, 'wb') as devnull: +def test_failing_shell_script() -> None: + yaml_path = os.path.join(yaml_tests_dir, "test_failure.yaml") + command = ["openfisca", "test", yaml_path, "-c", "openfisca_dummy_country"] + with open(os.devnull, "wb") as devnull: with pytest.raises(subprocess.CalledProcessError): - subprocess.check_call(command, stdout = devnull, stderr = devnull) - - -def test_shell_script_with_reform(): - yaml_path = os.path.join(yaml_tests_dir, 'test_with_reform_2.yaml') - command = ['openfisca', 'test', yaml_path, '-c', 'openfisca_country_template', '-r', 'openfisca_country_template.reforms.removal_basic_income.removal_basic_income'] - with open(os.devnull, 'wb') as devnull: - subprocess.check_call(command, stdout = devnull, stderr = devnull) - - -def test_shell_script_with_extension(): - tests_dir = os.path.join(openfisca_extension_template.__path__[0], 'tests') - command = ['openfisca', 'test', tests_dir, '-c', 'openfisca_country_template', '-e', 'openfisca_extension_template'] - with open(os.devnull, 'wb') as devnull: - subprocess.check_call(command, stdout = devnull, stderr = devnull) + subprocess.check_call(command, stdout=devnull, stderr=devnull) + + +def test_shell_script_with_reform() -> None: + yaml_path = os.path.join(yaml_tests_dir, "test_with_reform_2.yaml") + command = [ + "openfisca", + "test", + yaml_path, + "-c", + "openfisca_country_template", + "-r", + "openfisca_country_template.reforms.removal_basic_income.removal_basic_income", + ] + with open(os.devnull, "wb") as devnull: + subprocess.check_call(command, stdout=devnull, stderr=devnull) + + +def test_shell_script_with_extension() -> None: + tests_dir = os.path.join(openfisca_extension_template.__path__[0], "tests") + command = [ + "openfisca", + "test", + tests_dir, + "-c", + "openfisca_country_template", + "-e", + "openfisca_extension_template", + ] + with open(os.devnull, "wb") as devnull: + subprocess.check_call(command, stdout=devnull, stderr=devnull) diff --git a/tests/core/tools/test_assert_near.py b/tests/core/tools/test_assert_near.py index eecf9d1d1f..c351be0f9c 100644 --- a/tests/core/tools/test_assert_near.py +++ b/tests/core/tools/test_assert_near.py @@ -1,21 +1,25 @@ -import numpy as np +import numpy from openfisca_core.tools import assert_near -def test_date(): - assert_near(np.array("2012-03-24", dtype = 'datetime64[D]'), "2012-03-24") +def test_date() -> None: + assert_near(numpy.array("2012-03-24", dtype="datetime64[D]"), "2012-03-24") -def test_enum(tax_benefit_system): - possible_values = tax_benefit_system.variables['housing_occupancy_status'].possible_values - value = possible_values.encode(np.array(['tenant'])) - expected_value = 'tenant' +def test_enum(tax_benefit_system) -> None: + possible_values = tax_benefit_system.variables[ + "housing_occupancy_status" + ].possible_values + value = possible_values.encode(numpy.array(["tenant"])) + expected_value = "tenant" assert_near(value, expected_value) -def test_enum_2(tax_benefit_system): - possible_values = tax_benefit_system.variables['housing_occupancy_status'].possible_values - value = possible_values.encode(np.array(['tenant', 'owner'])) - expected_value = ['tenant', 'owner'] +def test_enum_2(tax_benefit_system) -> None: + possible_values = tax_benefit_system.variables[ + "housing_occupancy_status" + ].possible_values + value = possible_values.encode(numpy.array(["tenant", "owner"])) + expected_value = ["tenant", "owner"] assert_near(value, expected_value) diff --git a/tests/core/tools/test_runner/test_yaml_runner.py b/tests/core/tools/test_runner/test_yaml_runner.py index 8f0b9c2bbb..6a02d14cef 100644 --- a/tests/core/tools/test_runner/test_yaml_runner.py +++ b/tests/core/tools/test_runner/test_yaml_runner.py @@ -1,22 +1,21 @@ import os -from typing import List +import numpy import pytest -import numpy as np -from openfisca_core.tools.test_runner import _get_tax_benefit_system, YamlItem, YamlFile -from openfisca_core.errors import VariableNotFound -from openfisca_core.variables import Variable -from openfisca_core.populations import Population +from openfisca_core import errors from openfisca_core.entities import Entity -from openfisca_core.periods import ETERNITY +from openfisca_core.periods import DateUnit +from openfisca_core.populations import Population +from openfisca_core.tools.test_runner import YamlFile, YamlItem, _get_tax_benefit_system +from openfisca_core.variables import Variable class TaxBenefitSystem: - def __init__(self): - self.variables = {'salary': TestVariable()} - self.person_entity = Entity('person', 'persons', None, "") - self.person_entity.tax_benefit_system = self + def __init__(self) -> None: + self.variables = {"salary": TestVariable()} + self.person_entity = Entity("person", "persons", None, "") + self.person_entity.set_tax_benefit_system(self) def get_package_metadata(self): return {"name": "Test", "version": "Test"} @@ -24,7 +23,7 @@ def get_package_metadata(self): def apply_reform(self, path): return Reform(self) - def load_extension(self, extension): + def load_extension(self, extension) -> None: pass def entities_by_singular(self): @@ -34,9 +33,9 @@ def entities_plural(self): return {} def instantiate_entities(self): - return {'person': Population(self.person_entity)} + return {"person": Population(self.person_entity)} - def get_variable(self, variable_name, check_existence = True): + def get_variable(self, variable_name: str, check_existence=True): return self.variables.get(variable_name) def clone(self): @@ -44,106 +43,118 @@ def clone(self): class Reform(TaxBenefitSystem): - def __init__(self, baseline): + def __init__(self, baseline) -> None: self.baseline = baseline class Simulation: - def __init__(self): + def __init__(self) -> None: self.populations = {"person": None} - def get_population(self, plural = None): + def get_population(self, plural=None) -> None: return None class TestFile(YamlFile): - - def __init__(self): + def __init__(self) -> None: self.config = None self.session = None - self._nodeid = 'testname' + self._nodeid = "testname" class TestItem(YamlItem): - def __init__(self, test): - super().__init__('', TestFile(), TaxBenefitSystem(), test, {}) + def __init__(self, test) -> None: + super().__init__("", TestFile(), TaxBenefitSystem(), test, {}) self.tax_benefit_system = self.baseline_tax_benefit_system self.simulation = Simulation() class TestVariable(Variable): - definition_period = ETERNITY + definition_period = DateUnit.ETERNITY value_type = float - def __init__(self): + def __init__(self) -> None: self.end = None - self.entity = Entity('person', 'persons', None, "") + self.entity = Entity("person", "persons", None, "") self.is_neutralized = False self.set_input = None - self.dtype = np.float32 + self.dtype = numpy.float32 -def test_variable_not_found(): +@pytest.mark.skip(reason="Deprecated node constructor") +def test_variable_not_found() -> None: test = {"output": {"unknown_variable": 0}} - with pytest.raises(VariableNotFound) as excinfo: + with pytest.raises(errors.VariableNotFoundError) as excinfo: test_item = TestItem(test) test_item.check_output() assert excinfo.value.variable_name == "unknown_variable" -def test_tax_benefit_systems_with_reform_cache(): +def test_tax_benefit_systems_with_reform_cache() -> None: baseline = TaxBenefitSystem() - ab_tax_benefit_system = _get_tax_benefit_system(baseline, 'ab', []) - ba_tax_benefit_system = _get_tax_benefit_system(baseline, 'ba', []) + ab_tax_benefit_system = _get_tax_benefit_system(baseline, "ab", []) + ba_tax_benefit_system = _get_tax_benefit_system(baseline, "ba", []) assert ab_tax_benefit_system != ba_tax_benefit_system -def test_reforms_formats(): +def test_reforms_formats() -> None: baseline = TaxBenefitSystem() - lonely_reform_tbs = _get_tax_benefit_system(baseline, 'lonely_reform', []) - list_lonely_reform_tbs = _get_tax_benefit_system(baseline, ['lonely_reform'], []) + lonely_reform_tbs = _get_tax_benefit_system(baseline, "lonely_reform", []) + list_lonely_reform_tbs = _get_tax_benefit_system(baseline, ["lonely_reform"], []) assert lonely_reform_tbs == list_lonely_reform_tbs -def test_reforms_order(): +def test_reforms_order() -> None: baseline = TaxBenefitSystem() - abba_tax_benefit_system = _get_tax_benefit_system(baseline, ['ab', 'ba'], []) - baab_tax_benefit_system = _get_tax_benefit_system(baseline, ['ba', 'ab'], []) - assert abba_tax_benefit_system != baab_tax_benefit_system # keep reforms order in cache + abba_tax_benefit_system = _get_tax_benefit_system(baseline, ["ab", "ba"], []) + baab_tax_benefit_system = _get_tax_benefit_system(baseline, ["ba", "ab"], []) + assert ( + abba_tax_benefit_system != baab_tax_benefit_system + ) # keep reforms order in cache -def test_tax_benefit_systems_with_extensions_cache(): +def test_tax_benefit_systems_with_extensions_cache() -> None: baseline = TaxBenefitSystem() - xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], 'xy') - yx_tax_benefit_system = _get_tax_benefit_system(baseline, [], 'yx') + xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], "xy") + yx_tax_benefit_system = _get_tax_benefit_system(baseline, [], "yx") assert xy_tax_benefit_system != yx_tax_benefit_system -def test_extensions_formats(): +def test_extensions_formats() -> None: baseline = TaxBenefitSystem() - lonely_extension_tbs = _get_tax_benefit_system(baseline, [], 'lonely_extension') - list_lonely_extension_tbs = _get_tax_benefit_system(baseline, [], ['lonely_extension']) + lonely_extension_tbs = _get_tax_benefit_system(baseline, [], "lonely_extension") + list_lonely_extension_tbs = _get_tax_benefit_system( + baseline, + [], + ["lonely_extension"], + ) assert lonely_extension_tbs == list_lonely_extension_tbs -def test_extensions_order(): +def test_extensions_order() -> None: baseline = TaxBenefitSystem() - xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], ['x', 'y']) - yx_tax_benefit_system = _get_tax_benefit_system(baseline, [], ['y', 'x']) - assert xy_tax_benefit_system == yx_tax_benefit_system # extensions order is ignored in cache + xy_tax_benefit_system = _get_tax_benefit_system(baseline, [], ["x", "y"]) + yx_tax_benefit_system = _get_tax_benefit_system(baseline, [], ["y", "x"]) + assert ( + xy_tax_benefit_system == yx_tax_benefit_system + ) # extensions order is ignored in cache -def test_performance_graph_option_output(): - test = {'input': {'salary': {'2017-01': 2000}}, 'output': {'salary': {'2017-01': 2000}}} +@pytest.mark.skip(reason="Deprecated node constructor") +def test_performance_graph_option_output() -> None: + test = { + "input": {"salary": {"2017-01": 2000}}, + "output": {"salary": {"2017-01": 2000}}, + } test_item = TestItem(test) - test_item.options = {'performance_graph': True} + test_item.options = {"performance_graph": True} paths = ["./performance_graph.html"] @@ -158,10 +169,14 @@ def test_performance_graph_option_output(): clean_performance_files(paths) -def test_performance_tables_option_output(): - test = {'input': {'salary': {'2017-01': 2000}}, 'output': {'salary': {'2017-01': 2000}}} +@pytest.mark.skip(reason="Deprecated node constructor") +def test_performance_tables_option_output() -> None: + test = { + "input": {"salary": {"2017-01": 2000}}, + "output": {"salary": {"2017-01": 2000}}, + } test_item = TestItem(test) - test_item.options = {'performance_tables': True} + test_item.options = {"performance_tables": True} paths = ["performance_table.csv", "aggregated_performance_table.csv"] @@ -176,7 +191,7 @@ def test_performance_tables_option_output(): clean_performance_files(paths) -def clean_performance_files(paths: List[str]): +def clean_performance_files(paths: list[str]) -> None: for path in paths: if os.path.isfile(path): os.remove(path) diff --git a/tests/core/variables/test_annualize.py b/tests/core/variables/test_annualize.py index 62b0a79b14..58ea1372dd 100644 --- a/tests/core/variables/test_annualize.py +++ b/tests/core/variables/test_annualize.py @@ -1,25 +1,25 @@ -import numpy as np +import numpy from pytest import fixture -from openfisca_core import periods -from openfisca_core.model_api import * # noqa analysis:ignore from openfisca_country_template.entities import Person -from openfisca_core.variables import get_annualized_variable + +from openfisca_core import periods +from openfisca_core.periods import DateUnit +from openfisca_core.variables import Variable, get_annualized_variable @fixture def monthly_variable(): - calculation_count = 0 class monthly_variable(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH - def formula(person, period, parameters): + def formula(self, period, parameters): variable.calculation_count += 1 - return np.asarray([100]) + return numpy.asarray([100]) variable = monthly_variable() variable.calculation_count = calculation_count @@ -30,55 +30,57 @@ def formula(person, period, parameters): class PopulationMock: # Simulate a population for whom a variable has already been put in cache for January. - def __init__(self, variable): + def __init__(self, variable) -> None: self.variable = variable - def __call__(self, variable_name, period): + def __call__(self, variable_name: str, period): if period.start.month == 1: - return np.asarray([100]) - else: - return self.variable.get_formula(period)(self, period, None) + return numpy.asarray([100]) + return self.variable.get_formula(period)(self, period, None) -def test_without_annualize(monthly_variable): +def test_without_annualize(monthly_variable) -> None: period = periods.period(2019) person = PopulationMock(monthly_variable) yearly_sum = sum( - person('monthly_variable', month) - for month in period.get_subperiods(MONTH) - ) + person("monthly_variable", month) + for month in period.get_subperiods(DateUnit.MONTH) + ) assert monthly_variable.calculation_count == 11 assert yearly_sum == 1200 -def test_with_annualize(monthly_variable): +def test_with_annualize(monthly_variable) -> None: period = periods.period(2019) annualized_variable = get_annualized_variable(monthly_variable) person = PopulationMock(annualized_variable) yearly_sum = sum( - person('monthly_variable', month) - for month in period.get_subperiods(MONTH) - ) + person("monthly_variable", month) + for month in period.get_subperiods(DateUnit.MONTH) + ) assert monthly_variable.calculation_count == 0 assert yearly_sum == 100 * 12 -def test_with_partial_annualize(monthly_variable): - period = periods.period('year:2018:2') - annualized_variable = get_annualized_variable(monthly_variable, periods.period(2018)) +def test_with_partial_annualize(monthly_variable) -> None: + period = periods.period("year:2018:2") + annualized_variable = get_annualized_variable( + monthly_variable, + periods.period(2018), + ) person = PopulationMock(annualized_variable) yearly_sum = sum( - person('monthly_variable', month) - for month in period.get_subperiods(MONTH) - ) + person("monthly_variable", month) + for month in period.get_subperiods(DateUnit.MONTH) + ) assert monthly_variable.calculation_count == 11 assert yearly_sum == 100 * 12 * 2 diff --git a/tests/core/variables/test_definition_period.py b/tests/core/variables/test_definition_period.py new file mode 100644 index 0000000000..8ef9bfaa87 --- /dev/null +++ b/tests/core/variables/test_definition_period.py @@ -0,0 +1,43 @@ +import pytest + +from openfisca_core import periods +from openfisca_core.variables import Variable + + +@pytest.fixture +def variable(persons): + class TestVariable(Variable): + value_type = float + entity = persons + + return TestVariable + + +def test_weekday_variable(variable) -> None: + variable.definition_period = periods.WEEKDAY + assert variable() + + +def test_week_variable(variable) -> None: + variable.definition_period = periods.WEEK + assert variable() + + +def test_day_variable(variable) -> None: + variable.definition_period = periods.DAY + assert variable() + + +def test_month_variable(variable) -> None: + variable.definition_period = periods.MONTH + assert variable() + + +def test_year_variable(variable) -> None: + variable.definition_period = periods.YEAR + assert variable() + + +def test_eternity_variable(variable) -> None: + variable.definition_period = periods.ETERNITY + assert variable() diff --git a/tests/core/variables/test_variables.py b/tests/core/variables/test_variables.py index 876145bde1..d5d85a70d9 100644 --- a/tests/core/variables/test_variables.py +++ b/tests/core/variables/test_variables.py @@ -1,17 +1,15 @@ -# -*- coding: utf-8 -*- - import datetime -from openfisca_core.model_api import Variable -from openfisca_core.periods import MONTH, ETERNITY -from openfisca_core.simulation_builder import SimulationBuilder -from openfisca_core.tools import assert_near +from pytest import fixture, mark, raises import openfisca_country_template as country_template import openfisca_country_template.situation_examples from openfisca_country_template.entities import Person -from pytest import fixture, raises, mark +from openfisca_core.periods import DateUnit +from openfisca_core.simulation_builder import SimulationBuilder +from openfisca_core.tools import assert_near +from openfisca_core.variables import Variable # Check which date is applied whether it comes from Variable attribute (end) # or formula(s) dates. @@ -22,27 +20,39 @@ # HELPERS + @fixture def couple(): - return SimulationBuilder().build_from_entities(tax_benefit_system, openfisca_country_template.situation_examples.couple) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + openfisca_country_template.situation_examples.couple, + ) @fixture def simulation(): - return SimulationBuilder().build_from_entities(tax_benefit_system, openfisca_country_template.situation_examples.single) + return SimulationBuilder().build_from_entities( + tax_benefit_system, + openfisca_country_template.situation_examples.single, + ) def vectorize(individu, number): return individu.filled_array(number) -def check_error_at_add_variable(tax_benefit_system, variable, error_message_prefix): +def check_error_at_add_variable( + tax_benefit_system, variable, error_message_prefix +) -> None: try: tax_benefit_system.add_variable(variable) except ValueError as e: message = get_message(e) if not message or not message.startswith(error_message_prefix): - raise AssertionError('Incorrect error message. Was expecting something starting by "{}". Got: "{}"'.format(error_message_prefix, message)) + msg = f'Incorrect error message. Was expecting something starting by "{error_message_prefix}". Got: "{message}"' + raise AssertionError( + msg, + ) def get_message(error): @@ -58,104 +68,111 @@ def get_message(error): class variable__no_date(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable without date." -def test_before_add__variable__no_date(): - assert tax_benefit_system.variables.get('variable__no_date') is None +def test_before_add__variable__no_date() -> None: + assert tax_benefit_system.variables.get("variable__no_date") is None -def test_variable__no_date(): +def test_variable__no_date() -> None: tax_benefit_system.add_variable(variable__no_date) - variable = tax_benefit_system.variables['variable__no_date'] + variable = tax_benefit_system.variables["variable__no_date"] assert variable.end is None assert len(variable.formulas) == 0 + # end, no formula class variable__strange_end_attribute(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with dubious end attribute, no formula." - end = '1989-00-00' + end = "1989-00-00" -def test_variable__strange_end_attribute(): +def test_variable__strange_end_attribute() -> None: try: tax_benefit_system.add_variable(variable__strange_end_attribute) except ValueError as e: message = get_message(e) - assert message.startswith("Incorrect 'end' attribute format in 'variable__strange_end_attribute'.") + assert message.startswith( + "Incorrect 'end' attribute format in 'variable__strange_end_attribute'.", + ) # Check that Error at variable adding prevents it from registration in the taxbenefitsystem. - assert not tax_benefit_system.variables.get('variable__strange_end_attribute') + assert not tax_benefit_system.variables.get("variable__strange_end_attribute") # end, no formula + class variable__end_attribute(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with end attribute, no formula." - end = '1989-12-31' + end = "1989-12-31" tax_benefit_system.add_variable(variable__end_attribute) -def test_variable__end_attribute(): - variable = tax_benefit_system.variables['variable__end_attribute'] +def test_variable__end_attribute() -> None: + variable = tax_benefit_system.variables["variable__end_attribute"] assert variable.end == datetime.date(1989, 12, 31) -def test_variable__end_attribute_set_input(simulation): - month_before_end = '1989-01' - month_after_end = '1990-01' - simulation.set_input('variable__end_attribute', month_before_end, 10) - simulation.set_input('variable__end_attribute', month_after_end, 10) - assert simulation.calculate('variable__end_attribute', month_before_end) == 10 - assert simulation.calculate('variable__end_attribute', month_after_end) == 0 +def test_variable__end_attribute_set_input(simulation) -> None: + month_before_end = "1989-01" + month_after_end = "1990-01" + simulation.set_input("variable__end_attribute", month_before_end, 10) + simulation.set_input("variable__end_attribute", month_after_end, 10) + assert simulation.calculate("variable__end_attribute", month_before_end) == 10 + assert simulation.calculate("variable__end_attribute", month_after_end) == 0 # end, one formula without date + class end_attribute__one_simple_formula(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with end attribute, one formula without date." - end = '1989-12-31' + end = "1989-12-31" - def formula(individu, period): - return vectorize(individu, 100) + def formula(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute__one_simple_formula) -def test_formulas_attributes_single_formula(): - formulas = tax_benefit_system.variables['end_attribute__one_simple_formula'].formulas - assert formulas['0001-01-01'] is not None +def test_formulas_attributes_single_formula() -> None: + formulas = tax_benefit_system.variables[ + "end_attribute__one_simple_formula" + ].formulas + assert formulas["0001-01-01"] is not None -def test_call__end_attribute__one_simple_formula(simulation): - month = '1979-12' - assert simulation.calculate('end_attribute__one_simple_formula', month) == 100 +def test_call__end_attribute__one_simple_formula(simulation) -> None: + month = "1979-12" + assert simulation.calculate("end_attribute__one_simple_formula", month) == 100 - month = '1989-12' - assert simulation.calculate('end_attribute__one_simple_formula', month) == 100 + month = "1989-12" + assert simulation.calculate("end_attribute__one_simple_formula", month) == 100 - month = '1990-01' - assert simulation.calculate('end_attribute__one_simple_formula', month) == 0 + month = "1990-01" + assert simulation.calculate("end_attribute__one_simple_formula", month) == 0 -def test_dates__end_attribute__one_simple_formula(): - variable = tax_benefit_system.variables['end_attribute__one_simple_formula'] +def test_dates__end_attribute__one_simple_formula() -> None: + variable = tax_benefit_system.variables["end_attribute__one_simple_formula"] assert variable.end == datetime.date(1989, 12, 31) assert len(variable.formulas) == 1 @@ -167,86 +184,93 @@ def test_dates__end_attribute__one_simple_formula(): # formula, strange name + class no_end_attribute__one_formula__strange_name(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable without end attribute, one stangely named formula." - def formula_2015_toto(individu, period): - return vectorize(individu, 100) + def formula_2015_toto(self, period): + return vectorize(self, 100) -def test_add__no_end_attribute__one_formula__strange_name(): - check_error_at_add_variable(tax_benefit_system, no_end_attribute__one_formula__strange_name, - 'Unrecognized formula name in variable "no_end_attribute__one_formula__strange_name". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: ') +def test_add__no_end_attribute__one_formula__strange_name() -> None: + check_error_at_add_variable( + tax_benefit_system, + no_end_attribute__one_formula__strange_name, + 'Unrecognized formula name in variable "no_end_attribute__one_formula__strange_name". Expecting "formula_YYYY" or "formula_YYYY_MM" or "formula_YYYY_MM_DD where YYYY, MM and DD are year, month and day. Found: ', + ) # formula, start + class no_end_attribute__one_formula__start(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable without end attribute, one dated formula." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(no_end_attribute__one_formula__start) -def test_call__no_end_attribute__one_formula__start(simulation): - month = '1999-12' - assert simulation.calculate('no_end_attribute__one_formula__start', month) == 0 +def test_call__no_end_attribute__one_formula__start(simulation) -> None: + month = "1999-12" + assert simulation.calculate("no_end_attribute__one_formula__start", month) == 0 - month = '2000-05' - assert simulation.calculate('no_end_attribute__one_formula__start', month) == 100 + month = "2000-05" + assert simulation.calculate("no_end_attribute__one_formula__start", month) == 100 - month = '2020-01' - assert simulation.calculate('no_end_attribute__one_formula__start', month) == 100 + month = "2020-01" + assert simulation.calculate("no_end_attribute__one_formula__start", month) == 100 -def test_dates__no_end_attribute__one_formula__start(): - variable = tax_benefit_system.variables['no_end_attribute__one_formula__start'] +def test_dates__no_end_attribute__one_formula__start() -> None: + variable = tax_benefit_system.variables["no_end_attribute__one_formula__start"] assert variable.end is None assert len(variable.formulas) == 1 - assert variable.formulas.keys()[0] == '2000-01-01' + assert variable.formulas.keys()[0] == "2000-01-01" class no_end_attribute__one_formula__eternity(Variable): value_type = int entity = Person - definition_period = ETERNITY # For this entity, this variable shouldn't evolve through time + definition_period = ( + DateUnit.ETERNITY + ) # For this entity, this variable shouldn't evolve through time label = "Variable without end attribute, one dated formula." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(no_end_attribute__one_formula__eternity) @mark.xfail() -def test_call__no_end_attribute__one_formula__eternity(simulation): - month = '1999-12' - assert simulation.calculate('no_end_attribute__one_formula__eternity', month) == 0 +def test_call__no_end_attribute__one_formula__eternity(simulation) -> None: + month = "1999-12" + assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0 # This fails because a definition period of "ETERNITY" caches for all periods - month = '2000-01' - assert simulation.calculate('no_end_attribute__one_formula__eternity', month) == 100 + month = "2000-01" + assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100 -def test_call__no_end_attribute__one_formula__eternity_before(simulation): - month = '1999-12' - assert simulation.calculate('no_end_attribute__one_formula__eternity', month) == 0 +def test_call__no_end_attribute__one_formula__eternity_before(simulation) -> None: + month = "1999-12" + assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 0 -def test_call__no_end_attribute__one_formula__eternity_after(simulation): - month = '2000-01' - assert simulation.calculate('no_end_attribute__one_formula__eternity', month) == 100 +def test_call__no_end_attribute__one_formula__eternity_after(simulation) -> None: + month = "2000-01" + assert simulation.calculate("no_end_attribute__one_formula__eternity", month) == 100 # formula, different start formats @@ -255,97 +279,123 @@ def test_call__no_end_attribute__one_formula__eternity_after(simulation): class no_end_attribute__formulas__start_formats(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable without end attribute, multiple dated formulas." - def formula_2000(individu, period): - return vectorize(individu, 100) + def formula_2000(self, period): + return vectorize(self, 100) - def formula_2010_01(individu, period): - return vectorize(individu, 200) + def formula_2010_01(self, period): + return vectorize(self, 200) tax_benefit_system.add_variable(no_end_attribute__formulas__start_formats) -def test_formulas_attributes_dated_formulas(): - formulas = tax_benefit_system.variables['no_end_attribute__formulas__start_formats'].formulas - assert(len(formulas) == 2) - assert formulas['2000-01-01'] is not None - assert formulas['2010-01-01'] is not None +def test_formulas_attributes_dated_formulas() -> None: + formulas = tax_benefit_system.variables[ + "no_end_attribute__formulas__start_formats" + ].formulas + assert len(formulas) == 2 + assert formulas["2000-01-01"] is not None + assert formulas["2010-01-01"] is not None -def test_get_formulas(): - variable = tax_benefit_system.variables['no_end_attribute__formulas__start_formats'] - formula_2000 = variable.formulas['2000-01-01'] - formula_2010 = variable.formulas['2010-01-01'] +def test_get_formulas() -> None: + variable = tax_benefit_system.variables["no_end_attribute__formulas__start_formats"] + formula_2000 = variable.formulas["2000-01-01"] + formula_2010 = variable.formulas["2010-01-01"] - assert variable.get_formula('1999-01') is None - assert variable.get_formula('2000-01') == formula_2000 - assert variable.get_formula('2009-12') == formula_2000 - assert variable.get_formula('2009-12-31') == formula_2000 - assert variable.get_formula('2010-01') == formula_2010 - assert variable.get_formula('2010-01-01') == formula_2010 + assert variable.get_formula("1999-01") is None + assert variable.get_formula("2000-01") == formula_2000 + assert variable.get_formula("2009-12") == formula_2000 + assert variable.get_formula("2009-12-31") == formula_2000 + assert variable.get_formula("2010-01") == formula_2010 + assert variable.get_formula("2010-01-01") == formula_2010 -def test_call__no_end_attribute__formulas__start_formats(simulation): - month = '1999-12' - assert simulation.calculate('no_end_attribute__formulas__start_formats', month) == 0 +def test_call__no_end_attribute__formulas__start_formats(simulation) -> None: + month = "1999-12" + assert simulation.calculate("no_end_attribute__formulas__start_formats", month) == 0 - month = '2000-01' - assert simulation.calculate('no_end_attribute__formulas__start_formats', month) == 100 + month = "2000-01" + assert ( + simulation.calculate("no_end_attribute__formulas__start_formats", month) == 100 + ) - month = '2009-12' - assert simulation.calculate('no_end_attribute__formulas__start_formats', month) == 100 + month = "2009-12" + assert ( + simulation.calculate("no_end_attribute__formulas__start_formats", month) == 100 + ) - month = '2010-01' - assert simulation.calculate('no_end_attribute__formulas__start_formats', month) == 200 + month = "2010-01" + assert ( + simulation.calculate("no_end_attribute__formulas__start_formats", month) == 200 + ) # Multiple formulas, different names with date overlap + class no_attribute__formulas__different_names__dates_overlap(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable, no end attribute, multiple dated formulas with different names but same dates." - def formula_2000(individu, period): - return vectorize(individu, 100) + def formula_2000(self, period): + return vectorize(self, 100) - def formula_2000_01_01(individu, period): - return vectorize(individu, 200) + def formula_2000_01_01(self, period): + return vectorize(self, 200) -def test_add__no_attribute__formulas__different_names__dates_overlap(): +def test_add__no_attribute__formulas__different_names__dates_overlap() -> None: # Variable isn't registered in the taxbenefitsystem - check_error_at_add_variable(tax_benefit_system, no_attribute__formulas__different_names__dates_overlap, "Dated formulas overlap") + check_error_at_add_variable( + tax_benefit_system, + no_attribute__formulas__different_names__dates_overlap, + "Dated formulas overlap", + ) # formula(start), different names, no date overlap + class no_attribute__formulas__different_names__no_overlap(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable, no end attribute, multiple dated formulas with different names and no date overlap." - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) - def formula_2010_01_01(individu, period): - return vectorize(individu, 200) + def formula_2010_01_01(self, period): + return vectorize(self, 200) tax_benefit_system.add_variable(no_attribute__formulas__different_names__no_overlap) -def test_call__no_attribute__formulas__different_names__no_overlap(simulation): - month = '2009-12' - assert simulation.calculate('no_attribute__formulas__different_names__no_overlap', month) == 100 +def test_call__no_attribute__formulas__different_names__no_overlap(simulation) -> None: + month = "2009-12" + assert ( + simulation.calculate( + "no_attribute__formulas__different_names__no_overlap", + month, + ) + == 100 + ) - month = '2015-05' - assert simulation.calculate('no_attribute__formulas__different_names__no_overlap', month) == 200 + month = "2015-05" + assert ( + simulation.calculate( + "no_attribute__formulas__different_names__no_overlap", + month, + ) + == 200 + ) # END ATTRIBUTE - DATED FORMULA(S) @@ -353,123 +403,145 @@ def test_call__no_attribute__formulas__different_names__no_overlap(simulation): # formula, start. + class end_attribute__one_formula__start(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with end attribute, one dated formula." - end = '2001-12-31' + end = "2001-12-31" - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute__one_formula__start) -def test_call__end_attribute__one_formula__start(simulation): - month = '1980-01' - assert simulation.calculate('end_attribute__one_formula__start', month) == 0 +def test_call__end_attribute__one_formula__start(simulation) -> None: + month = "1980-01" + assert simulation.calculate("end_attribute__one_formula__start", month) == 0 - month = '2000-01' - assert simulation.calculate('end_attribute__one_formula__start', month) == 100 + month = "2000-01" + assert simulation.calculate("end_attribute__one_formula__start", month) == 100 - month = '2002-01' - assert simulation.calculate('end_attribute__one_formula__start', month) == 0 + month = "2002-01" + assert simulation.calculate("end_attribute__one_formula__start", month) == 0 # end < formula, start. + class stop_attribute_before__one_formula__start(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with stop attribute only coming before formula start." - end = '1990-01-01' + end = "1990-01-01" - def formula_2000_01_01(individu, period): - return vectorize(individu, 0) + def formula_2000_01_01(self, period): + return vectorize(self, 0) -def test_add__stop_attribute_before__one_formula__start(): - check_error_at_add_variable(tax_benefit_system, stop_attribute_before__one_formula__start, 'You declared that "stop_attribute_before__one_formula__start" ends on "1990-01-01", but you wrote a formula to calculate it from "2000-01-01"') +def test_add__stop_attribute_before__one_formula__start() -> None: + check_error_at_add_variable( + tax_benefit_system, + stop_attribute_before__one_formula__start, + 'You declared that "stop_attribute_before__one_formula__start" ends on "1990-01-01", but you wrote a formula to calculate it from "2000-01-01"', + ) # end, formula with dates intervals overlap. + class end_attribute_restrictive__one_formula(Variable): value_type = int entity = Person - definition_period = MONTH - label = "Variable with end attribute, one dated formula and dates intervals overlap." - end = '2001-01-01' + definition_period = DateUnit.MONTH + label = ( + "Variable with end attribute, one dated formula and dates intervals overlap." + ) + end = "2001-01-01" - def formula_2001_01_01(individu, period): - return vectorize(individu, 100) + def formula_2001_01_01(self, period): + return vectorize(self, 100) tax_benefit_system.add_variable(end_attribute_restrictive__one_formula) -def test_call__end_attribute_restrictive__one_formula(simulation): - month = '2000-12' - assert simulation.calculate('end_attribute_restrictive__one_formula', month) == 0 +def test_call__end_attribute_restrictive__one_formula(simulation) -> None: + month = "2000-12" + assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 0 - month = '2001-01' - assert simulation.calculate('end_attribute_restrictive__one_formula', month) == 100 + month = "2001-01" + assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 100 - month = '2000-05' - assert simulation.calculate('end_attribute_restrictive__one_formula', month) == 0 + month = "2000-05" + assert simulation.calculate("end_attribute_restrictive__one_formula", month) == 0 # formulas of different names (without dates overlap on formulas) + class end_attribute__formulas__different_names(Variable): value_type = int entity = Person - definition_period = MONTH + definition_period = DateUnit.MONTH label = "Variable with end attribute, multiple dated formulas with different names." - end = '2010-12-31' + end = "2010-12-31" - def formula_2000_01_01(individu, period): - return vectorize(individu, 100) + def formula_2000_01_01(self, period): + return vectorize(self, 100) - def formula_2005_01_01(individu, period): - return vectorize(individu, 200) + def formula_2005_01_01(self, period): + return vectorize(self, 200) - def formula_2010_01_01(individu, period): - return vectorize(individu, 300) + def formula_2010_01_01(self, period): + return vectorize(self, 300) tax_benefit_system.add_variable(end_attribute__formulas__different_names) -def test_call__end_attribute__formulas__different_names(simulation): - month = '2000-01' - assert simulation.calculate('end_attribute__formulas__different_names', month) == 100 +def test_call__end_attribute__formulas__different_names(simulation) -> None: + month = "2000-01" + assert ( + simulation.calculate("end_attribute__formulas__different_names", month) == 100 + ) - month = '2005-01' - assert simulation.calculate('end_attribute__formulas__different_names', month) == 200 + month = "2005-01" + assert ( + simulation.calculate("end_attribute__formulas__different_names", month) == 200 + ) - month = '2010-12' - assert simulation.calculate('end_attribute__formulas__different_names', month) == 300 + month = "2010-12" + assert ( + simulation.calculate("end_attribute__formulas__different_names", month) == 300 + ) -def test_get_formula(simulation): +def test_get_formula(simulation) -> None: person = simulation.person - disposable_income_formula = tax_benefit_system.get_variable('disposable_income').get_formula() - disposable_income = person('disposable_income', '2017-01') - disposable_income_2 = disposable_income_formula(person, '2017-01', None) # No need for parameters here + disposable_income_formula = tax_benefit_system.get_variable( + "disposable_income", + ).get_formula() + disposable_income = person("disposable_income", "2017-01") + disposable_income_2 = disposable_income_formula( + person, + "2017-01", + None, + ) # No need for parameters here assert_near(disposable_income, disposable_income_2) -def test_unexpected_attr(): +def test_unexpected_attr() -> None: class variable_with_strange_attr(Variable): value_type = int entity = Person - definition_period = MONTH - unexpected = '???' + definition_period = DateUnit.MONTH + unexpected = "???" with raises(ValueError): tax_benefit_system.add_variable(variable_with_strange_attr) diff --git a/tests/fixtures/appclient.py b/tests/fixtures/appclient.py index a140e0f938..692747d393 100644 --- a/tests/fixtures/appclient.py +++ b/tests/fixtures/appclient.py @@ -5,7 +5,7 @@ @pytest.fixture(scope="module") def test_client(tax_benefit_system): - """ This module-scoped fixture creates an API client for the TBS defined in the `tax_benefit_system` + """This module-scoped fixture creates an API client for the TBS defined in the `tax_benefit_system` fixture. This `tax_benefit_system` is mutable, so you can add/update variables. Example: @@ -15,20 +15,22 @@ def test_client(tax_benefit_system): from openfisca_country_template import entities from openfisca_core import periods from openfisca_core.variables import Variable + ... + class new_variable(Variable): value_type = float entity = entities.Person - definition_period = periods.MONTH + definition_period = DateUnit.MONTH label = "New variable" reference = "https://law.gov.example/new_variable" # Always use the most official source + tax_benefit_system.add_variable(new_variable) flask_app = app.create_app(tax_benefit_system) """ - # Create the test API client flask_app = app.create_app(tax_benefit_system) return flask_app.test_client() diff --git a/tests/fixtures/entities.py b/tests/fixtures/entities.py index 7ab5c60a39..6670a68da1 100644 --- a/tests/fixtures/entities.py +++ b/tests/fixtures/entities.py @@ -6,38 +6,32 @@ class TestEntity(Entity): - @property - def variables(self): - return self - - def exists(self): - return self - - def isdefined(self): - return self - - def get(self, variable_name): + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result + def check_variable_defined_for_entity(self, variable_name: str) -> bool: + return True -class TestGroupEntity(GroupEntity): - @property - def variables(self): - return self - - def exists(self): - return self - - def isdefined(self): - return self - def get(self, variable_name): +class TestGroupEntity(GroupEntity): + def get_variable( + self, + variable_name: str, + check_existence: bool = False, + ) -> TestVariable: result = TestVariable(self) result.name = variable_name return result + def check_variable_defined_for_entity(self, variable_name: str) -> bool: + return True + @pytest.fixture def persons(): @@ -46,13 +40,9 @@ def persons(): @pytest.fixture def households(): - roles = [{ - 'key': 'parent', - 'plural': 'parents', - 'max': 2 - }, { - 'key': 'child', - 'plural': 'children' - }] + roles = [ + {"key": "parent", "plural": "parents", "max": 2}, + {"key": "child", "plural": "children"}, + ] return TestGroupEntity("household", "households", "", "", roles) diff --git a/tests/fixtures/extensions.py b/tests/fixtures/extensions.py new file mode 100644 index 0000000000..bc4e85fe72 --- /dev/null +++ b/tests/fixtures/extensions.py @@ -0,0 +1,18 @@ +from importlib import metadata + +import pytest + + +@pytest.fixture +def test_country_package_name() -> str: + return "openfisca_country_template" + + +@pytest.fixture +def test_extension_package_name() -> str: + return "openfisca_extension_template" + + +@pytest.fixture +def distribution(test_country_package_name): + return metadata.distribution(test_country_package_name) diff --git a/tests/fixtures/simulations.py b/tests/fixtures/simulations.py index 9d343d5ac0..53120b60d9 100644 --- a/tests/fixtures/simulations.py +++ b/tests/fixtures/simulations.py @@ -14,7 +14,7 @@ def simulation(tax_benefit_system, request): tax_benefit_system, variables, period, - ) + ) @pytest.fixture @@ -24,8 +24,4 @@ def make_simulation(): def _simulation(simulation_builder, tax_benefit_system, variables, period): simulation_builder.set_default_period(period) - simulation = \ - simulation_builder \ - .build_from_variables(tax_benefit_system, variables) - - return simulation + return simulation_builder.build_from_variables(tax_benefit_system, variables) diff --git a/tests/fixtures/taxbenefitsystems.py b/tests/fixtures/taxbenefitsystems.py index c2c47071ca..d29dfd73fd 100644 --- a/tests/fixtures/taxbenefitsystems.py +++ b/tests/fixtures/taxbenefitsystems.py @@ -3,7 +3,7 @@ from openfisca_country_template import CountryTaxBenefitSystem -@pytest.fixture(scope = "module") +@pytest.fixture(scope="module") def tax_benefit_system(): return CountryTaxBenefitSystem() diff --git a/tests/fixtures/variables.py b/tests/fixtures/variables.py index cd0d9b70ce..2deccf5891 100644 --- a/tests/fixtures/variables.py +++ b/tests/fixtures/variables.py @@ -1,11 +1,11 @@ -from openfisca_core import periods +from openfisca_core.periods import DateUnit from openfisca_core.variables import Variable class TestVariable(Variable): - definition_period = periods.ETERNITY + definition_period = DateUnit.ETERNITY value_type = float - def __init__(self, entity): + def __init__(self, entity) -> None: self.__class__.entity = entity super().__init__() diff --git a/tests/fixtures/yaml_tests/failing_test_absolute_error_margin.yaml b/tests/fixtures/yaml_tests/failing_test_absolute_error_margin.yaml index a51ae6894e..4928b06711 100644 --- a/tests/fixtures/yaml_tests/failing_test_absolute_error_margin.yaml +++ b/tests/fixtures/yaml_tests/failing_test_absolute_error_margin.yaml @@ -5,3 +5,14 @@ salary: 2000 output: income_tax: 351 # 300 + + +- name: "Failing test: result out of variable specific absolute error margin" + period: 2015-01 + absolute_error_margin: + default: 100 + income_tax: 50 + input: + salary: 2000 + output: + income_tax: 351 # 300 diff --git a/tests/fixtures/yaml_tests/failing_test_relative_error_margin.yaml b/tests/fixtures/yaml_tests/failing_test_relative_error_margin.yaml index 9258946c3d..c0788cfa96 100644 --- a/tests/fixtures/yaml_tests/failing_test_relative_error_margin.yaml +++ b/tests/fixtures/yaml_tests/failing_test_relative_error_margin.yaml @@ -5,3 +5,14 @@ salary: 2000 output: income_tax: 316 # 300 + + +- name: "Failing test: result out of variable specific relative error margin" + period: 2015-01 + relative_error_margin: + default: 1 + income_tax: 0.05 + input: + salary: 2000 + output: + income_tax: 316 # 300 diff --git a/tests/fixtures/yaml_tests/test_absolute_error_margin.yaml b/tests/fixtures/yaml_tests/test_absolute_error_margin.yaml index be7de2d5cb..65dbb308e3 100644 --- a/tests/fixtures/yaml_tests/test_absolute_error_margin.yaml +++ b/tests/fixtures/yaml_tests/test_absolute_error_margin.yaml @@ -5,3 +5,14 @@ salary: 2000 output: income_tax: 350 # 300 + + +- name: "Result within absolute error margin" + period: 2015-01 + absolute_error_margin: + default: 100 + income_tax: 50 + input: + salary: 2000 + output: + income_tax: 350 # 300 diff --git a/tests/fixtures/yaml_tests/test_relative_error_margin.yaml b/tests/fixtures/yaml_tests/test_relative_error_margin.yaml index 7845d6f361..d39a9e4143 100644 --- a/tests/fixtures/yaml_tests/test_relative_error_margin.yaml +++ b/tests/fixtures/yaml_tests/test_relative_error_margin.yaml @@ -5,3 +5,14 @@ salary: 2000 output: income_tax: 290 # 300 + + +- name: "Result within variable relative error margin" + period: 2015-01 + relative_error_margin: + default: .001 + income_tax: 0.05 + input: + salary: 2000 + output: + income_tax: 290 # 300 diff --git a/tests/web_api/__init__.py b/tests/web_api/__init__.py index 8098c2a5a2..e69de29bb2 100644 --- a/tests/web_api/__init__.py +++ b/tests/web_api/__init__.py @@ -1,4 +0,0 @@ -import pkg_resources - -TEST_COUNTRY_PACKAGE_NAME = 'openfisca_country_template' -distribution = pkg_resources.get_distribution(TEST_COUNTRY_PACKAGE_NAME) diff --git a/tests/web_api/basic_case/__init__.py b/tests/web_api/basic_case/__init__.py deleted file mode 100644 index 4114c06467..0000000000 --- a/tests/web_api/basic_case/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -import pkg_resources -from openfisca_web_api.app import create_app -from openfisca_core.scripts import build_tax_benefit_system - -TEST_COUNTRY_PACKAGE_NAME = 'openfisca_country_template' -distribution = pkg_resources.get_distribution(TEST_COUNTRY_PACKAGE_NAME) -tax_benefit_system = build_tax_benefit_system(TEST_COUNTRY_PACKAGE_NAME, extensions = None, reforms = None) -subject = create_app(tax_benefit_system).test_client() diff --git a/tests/web_api/case_with_extension/test_extensions.py b/tests/web_api/case_with_extension/test_extensions.py index 4da94bf45c..2c688232f8 100644 --- a/tests/web_api/case_with_extension/test_extensions.py +++ b/tests/web_api/case_with_extension/test_extensions.py @@ -1,28 +1,39 @@ -# -*- coding: utf-8 -*- +from http import client + +import pytest -from http.client import OK from openfisca_core.scripts import build_tax_benefit_system from openfisca_web_api.app import create_app -TEST_COUNTRY_PACKAGE_NAME = 'openfisca_country_template' -TEST_EXTENSION_PACKAGE_NAMES = ['openfisca_extension_template'] +@pytest.fixture +def tax_benefit_system(test_country_package_name, test_extension_package_name): + return build_tax_benefit_system( + test_country_package_name, + extensions=[test_extension_package_name], + reforms=None, + ) -tax_benefit_system = build_tax_benefit_system(TEST_COUNTRY_PACKAGE_NAME, extensions = TEST_EXTENSION_PACKAGE_NAMES, reforms = None) -extended_subject = create_app(tax_benefit_system).test_client() +@pytest.fixture +def extended_subject(tax_benefit_system): + return create_app(tax_benefit_system).test_client() -def test_return_code(): - parameters_response = extended_subject.get('/parameters') - assert parameters_response.status_code == OK +def test_return_code(extended_subject) -> None: + parameters_response = extended_subject.get("/parameters") + assert parameters_response.status_code == client.OK -def test_return_code_existing_parameter(): - extension_parameter_response = extended_subject.get('/parameter/local_town.child_allowance.amount') - assert extension_parameter_response.status_code == OK +def test_return_code_existing_parameter(extended_subject) -> None: + extension_parameter_response = extended_subject.get( + "/parameter/local_town.child_allowance.amount", + ) + assert extension_parameter_response.status_code == client.OK -def test_return_code_existing_variable(): - extension_variable_response = extended_subject.get('/variable/local_town_child_allowance') - assert extension_variable_response.status_code == OK +def test_return_code_existing_variable(extended_subject) -> None: + extension_variable_response = extended_subject.get( + "/variable/local_town_child_allowance", + ) + assert extension_variable_response.status_code == client.OK diff --git a/tests/web_api/case_with_reform/test_reforms.py b/tests/web_api/case_with_reform/test_reforms.py index 5037a4b395..f0895cf189 100644 --- a/tests/web_api/case_with_reform/test_reforms.py +++ b/tests/web_api/case_with_reform/test_reforms.py @@ -1,62 +1,65 @@ import http + import pytest from openfisca_core import scripts from openfisca_web_api import app -TEST_COUNTRY_PACKAGE_NAME = "openfisca_country_template" -TEST_REFORMS_PATHS = [ - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.add_dynamic_variable.add_dynamic_variable", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.add_new_tax.add_new_tax", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.flat_social_security_contribution.flat_social_security_contribution", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.modify_social_security_taxation.modify_social_security_taxation", - f"{TEST_COUNTRY_PACKAGE_NAME}.reforms.removal_basic_income.removal_basic_income", + +@pytest.fixture +def test_reforms_path(test_country_package_name): + return [ + f"{test_country_package_name}.reforms.add_dynamic_variable.add_dynamic_variable", + f"{test_country_package_name}.reforms.add_new_tax.add_new_tax", + f"{test_country_package_name}.reforms.flat_social_security_contribution.flat_social_security_contribution", + f"{test_country_package_name}.reforms.modify_social_security_taxation.modify_social_security_taxation", + f"{test_country_package_name}.reforms.removal_basic_income.removal_basic_income", ] # Create app as in 'openfisca serve' script @pytest.fixture -def client(): +def client(test_country_package_name, test_reforms_path): tax_benefit_system = scripts.build_tax_benefit_system( - TEST_COUNTRY_PACKAGE_NAME, - extensions = None, - reforms = TEST_REFORMS_PATHS, - ) + test_country_package_name, + extensions=None, + reforms=test_reforms_path, + ) return app.create_app(tax_benefit_system).test_client() -def test_return_code_of_dynamic_variable(client): +def test_return_code_of_dynamic_variable(client) -> None: result = client.get("/variable/goes_to_school") assert result.status_code == http.client.OK -def test_return_code_of_has_car_variable(client): +def test_return_code_of_has_car_variable(client) -> None: result = client.get("/variable/has_car") assert result.status_code == http.client.OK -def test_return_code_of_new_tax_variable(client): +def test_return_code_of_new_tax_variable(client) -> None: result = client.get("/variable/new_tax") assert result.status_code == http.client.OK -def test_return_code_of_social_security_contribution_variable(client): +def test_return_code_of_social_security_contribution_variable(client) -> None: result = client.get("/variable/social_security_contribution") assert result.status_code == http.client.OK -def test_return_code_of_social_security_contribution_parameter(client): +def test_return_code_of_social_security_contribution_parameter(client) -> None: result = client.get("/parameter/taxes.social_security_contribution") assert result.status_code == http.client.OK -def test_return_code_of_basic_income_variable(client): +def test_return_code_of_basic_income_variable(client) -> None: result = client.get("/variable/basic_income") assert result.status_code == http.client.OK diff --git a/tests/web_api/loader/test_parameters.py b/tests/web_api/loader/test_parameters.py index e17472a9d6..f44632ce49 100644 --- a/tests/web_api/loader/test_parameters.py +++ b/tests/web_api/loader/test_parameters.py @@ -1,35 +1,70 @@ -# -*- coding: utf-8 -*- - from openfisca_core.parameters import Scale - -from openfisca_web_api.loader.parameters import build_api_scale, build_api_parameter +from openfisca_web_api.loader.parameters import build_api_parameter, build_api_scale -def test_build_rate_scale(): - '''Extracts a 'rate' children from a bracket collection''' - data = {'brackets': [{'rate': {'2014-01-01': {'value': 0.5}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - rate = Scale('this rate', data, None) - assert build_api_scale(rate, 'rate') == {'2014-01-01': {1: 0.5}} +def test_build_rate_scale() -> None: + """Extracts a 'rate' children from a bracket collection.""" + data = { + "brackets": [ + { + "rate": {"2014-01-01": {"value": 0.5}}, + "threshold": {"2014-01-01": {"value": 1}}, + }, + ], + } + rate = Scale("this rate", data, None) + assert build_api_scale(rate, "rate") == {"2014-01-01": {1: 0.5}} -def test_build_amount_scale(): - '''Extracts an 'amount' children from a bracket collection''' - data = {'brackets': [{'amount': {'2014-01-01': {'value': 0}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - rate = Scale('that amount', data, None) - assert build_api_scale(rate, 'amount') == {'2014-01-01': {1: 0}} +def test_build_amount_scale() -> None: + """Extracts an 'amount' children from a bracket collection.""" + data = { + "brackets": [ + { + "amount": {"2014-01-01": {"value": 0}}, + "threshold": {"2014-01-01": {"value": 1}}, + }, + ], + } + rate = Scale("that amount", data, None) + assert build_api_scale(rate, "amount") == {"2014-01-01": {1: 0}} -def test_full_rate_scale(): - '''Serializes a 'rate' scale parameter''' - data = {'brackets': [{'rate': {'2014-01-01': {'value': 0.5}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - scale = Scale('rate', data, None) +def test_full_rate_scale() -> None: + """Serializes a 'rate' scale parameter.""" + data = { + "brackets": [ + { + "rate": {"2014-01-01": {"value": 0.5}}, + "threshold": {"2014-01-01": {"value": 1}}, + }, + ], + } + scale = Scale("rate", data, None) api_scale = build_api_parameter(scale, {}) - assert api_scale == {'description': None, 'id': 'rate', 'metadata': {}, 'brackets': {'2014-01-01': {1: 0.5}}} + assert api_scale == { + "description": None, + "id": "rate", + "metadata": {}, + "brackets": {"2014-01-01": {1: 0.5}}, + } -def test_walk_node_amount_scale(): - '''Serializes an 'amount' scale parameter ''' - data = {'brackets': [{'amount': {'2014-01-01': {'value': 0}}, 'threshold': {'2014-01-01': {'value': 1}}}]} - scale = Scale('amount', data, None) +def test_walk_node_amount_scale() -> None: + """Serializes an 'amount' scale parameter.""" + data = { + "brackets": [ + { + "amount": {"2014-01-01": {"value": 0}}, + "threshold": {"2014-01-01": {"value": 1}}, + }, + ], + } + scale = Scale("amount", data, None) api_scale = build_api_parameter(scale, {}) - assert api_scale == {'description': None, 'id': 'amount', 'metadata': {}, 'brackets': {'2014-01-01': {1: 0}}} + assert api_scale == { + "description": None, + "id": "amount", + "metadata": {}, + "brackets": {"2014-01-01": {1: 0}}, + } diff --git a/tests/web_api/test_calculate.py b/tests/web_api/test_calculate.py index b3415810a7..4d69dae9ab 100644 --- a/tests/web_api/test_calculate.py +++ b/tests/web_api/test_calculate.py @@ -1,351 +1,464 @@ import copy -import dpath import json -from http import client import os +from http import client + +import dpath.util import pytest from openfisca_country_template.situation_examples import couple -def post_json(client, data = None, file = None): +def post_json(client, data=None, file=None): if file: - file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', file) - with open(file_path, 'r') as file: + file_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "assets", + file, + ) + with open(file_path) as file: data = file.read() - return client.post('/calculate', data = data, content_type = 'application/json') + return client.post("/calculate", data=data, content_type="application/json") -def check_response(client, data, expected_error_code, path_to_check, content_to_check): +def check_response( + client, data, expected_error_code, path_to_check, content_to_check +) -> None: response = post_json(client, data) assert response.status_code == expected_error_code - json_response = json.loads(response.data.decode('utf-8')) + json_response = json.loads(response.data.decode("utf-8")) if path_to_check: content = dpath.util.get(json_response, path_to_check) assert content_to_check in content -@pytest.mark.parametrize("test", [ - ('{"a" : "x", "b"}', client.BAD_REQUEST, 'error', 'Invalid JSON'), - ('["An", "array"]', client.BAD_REQUEST, 'error', 'Invalid type'), - ('{"persons": {}}', client.BAD_REQUEST, 'persons', 'At least one person'), - ('{"persons": {"bob": {}}, "unknown_entity": {}}', client.BAD_REQUEST, 'unknown_entity', 'entities are not found',), - ('{"persons": {"bob": {}}, "households": {"dupont": {"parents": {}}}}', client.BAD_REQUEST, 'households/dupont/parents', 'type',), - ('{"persons": {"bob": {"unknown_variable": {}}}}', client.NOT_FOUND, 'persons/bob/unknown_variable', 'You tried to calculate or to set',), - ('{"persons": {"bob": {"housing_allowance": {}}}}', client.BAD_REQUEST, 'persons/bob/housing_allowance', "You tried to compute the variable 'housing_allowance' for the entity 'persons'",), - ('{"persons": {"bob": {"salary": 4000 }}}', client.BAD_REQUEST, 'persons/bob/salary', 'period',), - ('{"persons": {"bob": {"salary": {"2017-01": "toto"} }}}', client.BAD_REQUEST, 'persons/bob/salary/2017-01', 'expected type number',), - ('{"persons": {"bob": {"salary": {"2017-01": {}} }}}', client.BAD_REQUEST, 'persons/bob/salary/2017-01', 'expected type number',), - ('{"persons": {"bob": {"age": {"2017-01": "toto"} }}}', client.BAD_REQUEST, 'persons/bob/age/2017-01', 'expected type integer',), - ('{"persons": {"bob": {"birth": {"2017-01": "toto"} }}}', client.BAD_REQUEST, 'persons/bob/birth/2017-01', 'Can\'t deal with date',), - ('{"persons": {"bob": {}}, "households": {"household": {"parents": ["unexpected_person_id"]}}}', client.BAD_REQUEST, 'households/household/parents', 'has not been declared in persons',), - ('{"persons": {"bob": {}}, "households": {"household": {"parents": ["bob", "bob"]}}}', client.BAD_REQUEST, 'households/household/parents', 'has been declared more than once',), - ('{"persons": {"bob": {}}, "households": {"household": {"parents": ["bob", {}]}}}', client.BAD_REQUEST, 'households/household/parents/1', 'Invalid type',), - ('{"persons": {"bob": {"salary": {"invalid period": 2000 }}}}', client.BAD_REQUEST, 'persons/bob/salary', 'Expected a period',), - ('{"persons": {"bob": {"salary": {"invalid period": null }}}}', client.BAD_REQUEST, 'persons/bob/salary', 'Expected a period',), - ('{"persons": {"bob": {"basic_income": {"2017": 2000 }}}, "households": {"household": {"parents": ["bob"]}}}', client.BAD_REQUEST, 'persons/bob/basic_income/2017', '"basic_income" can only be set for one month',), - ('{"persons": {"bob": {"salary": {"ETERNITY": 2000 }}}, "households": {"household": {"parents": ["bob"]}}}', client.BAD_REQUEST, 'persons/bob/salary/ETERNITY', 'salary is only defined for months',), - ('{"persons": {"alice": {}, "bob": {}, "charlie": {}}, "households": {"_": {"parents": ["alice", "bob", "charlie"]}}}', client.BAD_REQUEST, 'households/_/parents', 'at most 2 parents in a household',), - ]) -def test_responses(test_client, test): +@pytest.mark.parametrize( + "test", + [ + ('{"a" : "x", "b"}', client.BAD_REQUEST, "error", "Invalid JSON"), + ('["An", "array"]', client.BAD_REQUEST, "error", "Invalid type"), + ('{"persons": {}}', client.BAD_REQUEST, "persons", "At least one person"), + ( + '{"persons": {"bob": {}}, "unknown_entity": {}}', + client.BAD_REQUEST, + "unknown_entity", + "entities are not found", + ), + ( + '{"persons": {"bob": {}}, "households": {"dupont": {"parents": {}}}}', + client.BAD_REQUEST, + "households/dupont/parents", + "type", + ), + ( + '{"persons": {"bob": {"unknown_variable": {}}}}', + client.NOT_FOUND, + "persons/bob/unknown_variable", + "You tried to calculate or to set", + ), + ( + '{"persons": {"bob": {"housing_allowance": {}}}}', + client.BAD_REQUEST, + "persons/bob/housing_allowance", + "You tried to compute the variable 'housing_allowance' for the entity 'persons'", + ), + ( + '{"persons": {"bob": {"salary": 4000 }}}', + client.BAD_REQUEST, + "persons/bob/salary", + "period", + ), + ( + '{"persons": {"bob": {"salary": {"2017-01": "toto"} }}}', + client.BAD_REQUEST, + "persons/bob/salary/2017-01", + "expected type number", + ), + ( + '{"persons": {"bob": {"salary": {"2017-01": {}} }}}', + client.BAD_REQUEST, + "persons/bob/salary/2017-01", + "expected type number", + ), + ( + '{"persons": {"bob": {"age": {"2017-01": "toto"} }}}', + client.BAD_REQUEST, + "persons/bob/age/2017-01", + "expected type integer", + ), + ( + '{"persons": {"bob": {"birth": {"2017-01": "toto"} }}}', + client.BAD_REQUEST, + "persons/bob/birth/2017-01", + "Can't deal with date", + ), + ( + '{"persons": {"bob": {}}, "households": {"household": {"parents": ["unexpected_person_id"]}}}', + client.BAD_REQUEST, + "households/household/parents", + "has not been declared in persons", + ), + ( + '{"persons": {"bob": {}}, "households": {"household": {"parents": ["bob", "bob"]}}}', + client.BAD_REQUEST, + "households/household/parents", + "has been declared more than once", + ), + ( + '{"persons": {"bob": {}}, "households": {"household": {"parents": ["bob", {}]}}}', + client.BAD_REQUEST, + "households/household/parents/1", + "Invalid type", + ), + ( + '{"persons": {"bob": {"salary": {"invalid period": 2000 }}}}', + client.BAD_REQUEST, + "persons/bob/salary", + "Expected a period", + ), + ( + '{"persons": {"bob": {"salary": {"invalid period": null }}}}', + client.BAD_REQUEST, + "persons/bob/salary", + "Expected a period", + ), + ( + '{"persons": {"bob": {"basic_income": {"2017": 2000 }}}, "households": {"household": {"parents": ["bob"]}}}', + client.BAD_REQUEST, + "persons/bob/basic_income/2017", + '"basic_income" can only be set for one month', + ), + ( + '{"persons": {"bob": {"salary": {"ETERNITY": 2000 }}}, "households": {"household": {"parents": ["bob"]}}}', + client.BAD_REQUEST, + "persons/bob/salary/ETERNITY", + "salary is only defined for months", + ), + ( + '{"persons": {"alice": {}, "bob": {}, "charlie": {}}, "households": {"_": {"parents": ["alice", "bob", "charlie"]}}}', + client.BAD_REQUEST, + "households/_/parents", + "at most 2 parents in a household", + ), + ], +) +def test_responses(test_client, test) -> None: check_response(test_client, *test) -def test_basic_calculation(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": { - "birth": { - "2017-12": "1980-01-01" - }, - "age": { - "2017-12": None - }, - "salary": { - "2017-12": 2000 - }, - "basic_income": { - "2017-12": None - }, - "income_tax": { - "2017-12": None - } +def test_basic_calculation(test_client) -> None: + simulation_json = json.dumps( + { + "persons": { + "bill": { + "birth": {"2017-12": "1980-01-01"}, + "age": {"2017-12": None}, + "salary": {"2017-12": 2000}, + "basic_income": {"2017-12": None}, + "income_tax": {"2017-12": None}, }, - "bob": { - "salary": { - "2017-12": 15000 - }, - "basic_income": { - "2017-12": None - }, - "social_security_contribution": { - "2017-12": None - } + "bob": { + "salary": {"2017-12": 15000}, + "basic_income": {"2017-12": None}, + "social_security_contribution": {"2017-12": None}, }, }, - "households": { - "first_household": { - "parents": ['bill', 'bob'], - "housing_tax": { - "2017": None - }, - "accommodation_size": { - "2017-01": 300 - } + "households": { + "first_household": { + "parents": ["bill", "bob"], + "housing_tax": {"2017": None}, + "accommodation_size": {"2017-01": 300}, }, - } - }) + }, + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) - assert dpath.get(response_json, 'persons/bill/basic_income/2017-12') == 600 # Universal basic income - assert dpath.get(response_json, 'persons/bill/income_tax/2017-12') == 300 # 15% of the salary - assert dpath.get(response_json, 'persons/bill/age/2017-12') == 37 # 15% of the salary - assert dpath.get(response_json, 'persons/bob/basic_income/2017-12') == 600 - assert dpath.get(response_json, 'persons/bob/social_security_contribution/2017-12') == 816 # From social_security_contribution.yaml test - assert dpath.get(response_json, 'households/first_household/housing_tax/2017') == 3000 - - -def test_enums_sending_identifier(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": {} + response_json = json.loads(response.data.decode("utf-8")) + assert ( + dpath.util.get(response_json, "persons/bill/basic_income/2017-12") == 600 + ) # Universal basic income + assert ( + dpath.util.get(response_json, "persons/bill/income_tax/2017-12") == 300 + ) # 15% of the salary + assert ( + dpath.util.get(response_json, "persons/bill/age/2017-12") == 37 + ) # 15% of the salary + assert dpath.util.get(response_json, "persons/bob/basic_income/2017-12") == 600 + assert ( + dpath.util.get( + response_json, + "persons/bob/social_security_contribution/2017-12", + ) + == 816 + ) # From social_security_contribution.yaml test + assert ( + dpath.util.get(response_json, "households/first_household/housing_tax/2017") + == 3000 + ) + + +def test_enums_sending_identifier(test_client) -> None: + simulation_json = json.dumps( + { + "persons": {"bill": {}}, + "households": { + "_": { + "parents": ["bill"], + "housing_tax": {"2017": None}, + "accommodation_size": {"2017-01": 300}, + "housing_occupancy_status": {"2017-01": "free_lodger"}, + }, }, - "households": { - "_": { - "parents": ["bill"], - "housing_tax": { - "2017": None - }, - "accommodation_size": { - "2017-01": 300 - }, - "housing_occupancy_status": { - "2017-01": "free_lodger" - } - } - } - }) + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) - assert dpath.get(response_json, 'households/_/housing_tax/2017') == 0 + response_json = json.loads(response.data.decode("utf-8")) + assert dpath.util.get(response_json, "households/_/housing_tax/2017") == 0 -def test_enum_output(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": {}, +def test_enum_output(test_client) -> None: + simulation_json = json.dumps( + { + "persons": { + "bill": {}, }, - "households": { - "_": { - "parents": ["bill"], - "housing_occupancy_status": { - "2017-01": None - } + "households": { + "_": { + "parents": ["bill"], + "housing_occupancy_status": {"2017-01": None}, }, - } - }) + }, + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) - assert dpath.get(response_json, "households/_/housing_occupancy_status/2017-01") == "tenant" - - -def test_enum_wrong_value(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": {}, + response_json = json.loads(response.data.decode("utf-8")) + assert ( + dpath.util.get(response_json, "households/_/housing_occupancy_status/2017-01") + == "tenant" + ) + + +def test_enum_wrong_value(test_client) -> None: + simulation_json = json.dumps( + { + "persons": { + "bill": {}, }, - "households": { - "_": { - "parents": ["bill"], - "housing_occupancy_status": { - "2017-01": "Unknown value lodger" - } + "households": { + "_": { + "parents": ["bill"], + "housing_occupancy_status": {"2017-01": "Unknown value lodger"}, }, - } - }) + }, + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.BAD_REQUEST - response_json = json.loads(response.data.decode('utf-8')) + response_json = json.loads(response.data.decode("utf-8")) message = "Possible values are ['owner', 'tenant', 'free_lodger', 'homeless']" - text = dpath.get(response_json, "households/_/housing_occupancy_status/2017-01") + text = dpath.util.get( + response_json, + "households/_/housing_occupancy_status/2017-01", + ) assert message in text -def test_encoding_variable_value(test_client): - simulation_json = json.dumps({ - "persons": { - "toto": {} - }, - "households": { - "_": { - "housing_occupancy_status": { - "2017-07": "Locataire ou sous-locataire d‘un logement loué vide non-HLM" - +def test_encoding_variable_value(test_client) -> None: + simulation_json = json.dumps( + { + "persons": {"toto": {}}, + "households": { + "_": { + "housing_occupancy_status": { + "2017-07": "Locataire ou sous-locataire d‘un logement loué vide non-HLM", }, - "parent": [ - "toto", - ] - } - } - }) + "parent": [ + "toto", + ], + }, + }, + }, + ) # No UnicodeDecodeError response = post_json(test_client, simulation_json) - assert response.status_code == client.BAD_REQUEST, response.data.decode('utf-8') - response_json = json.loads(response.data.decode('utf-8')) + assert response.status_code == client.BAD_REQUEST, response.data.decode("utf-8") + response_json = json.loads(response.data.decode("utf-8")) message = "'Locataire ou sous-locataire d‘un logement loué vide non-HLM' is not a known value for 'housing_occupancy_status'. Possible values are " - text = dpath.get(response_json, 'households/_/housing_occupancy_status/2017-07') + text = dpath.util.get( + response_json, + "households/_/housing_occupancy_status/2017-07", + ) assert message in text -def test_encoding_entity_name(test_client): - simulation_json = json.dumps({ - "persons": { - "O‘Ryan": {}, - "Renée": {} - }, - "households": { - "_": { - "parents": [ - "O‘Ryan", - "Renée" - ] - } - } - }) +def test_encoding_entity_name(test_client) -> None: + simulation_json = json.dumps( + { + "persons": {"O‘Ryan": {}, "Renée": {}}, + "households": {"_": {"parents": ["O‘Ryan", "Renée"]}}, + }, + ) # No UnicodeDecodeError response = post_json(test_client, simulation_json) - response_json = json.loads(response.data.decode('utf-8')) + response_json = json.loads(response.data.decode("utf-8")) # In Python 3, there is no encoding issue. if response.status_code != client.OK: message = "'O‘Ryan' is not a valid ASCII value." - text = response_json['error'] + text = response_json["error"] assert message in text -def test_encoding_period_id(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": { - "salary": { - "2017": 60000 - } +def test_encoding_period_id(test_client) -> None: + simulation_json = json.dumps( + { + "persons": { + "bill": {"salary": {"2017": 60000}}, + "bell": {"salary": {"2017": 60000}}, + }, + "households": { + "_": { + "parents": ["bill", "bell"], + "housing_tax": {"à": 400}, + "accommodation_size": {"2017-01": 300}, + "housing_occupancy_status": {"2017-01": "tenant"}, }, - "bell": { - "salary": { - "2017": 60000 - } - } }, - "households": { - "_": { - "parents": ["bill", "bell"], - "housing_tax": { - "à": 400 - }, - "accommodation_size": { - "2017-01": 300 - }, - "housing_occupancy_status": { - "2017-01": "tenant" - } - } - } - }) + }, + ) # No UnicodeDecodeError response = post_json(test_client, simulation_json) assert response.status_code == client.BAD_REQUEST - response_json = json.loads(response.data.decode('utf-8')) + response_json = json.loads(response.data.decode("utf-8")) # In Python 3, there is no encoding issue. if "Expected a period" not in str(response.data): message = "'à' is not a valid ASCII value." - text = response_json['error'] + text = response_json["error"] assert message in text -def test_str_variable(test_client): +def test_str_variable(test_client) -> None: new_couple = copy.deepcopy(couple) - new_couple['households']['_']['postal_code'] = {'2017-01': None} + new_couple["households"]["_"]["postal_code"] = {"2017-01": None} simulation_json = json.dumps(new_couple) - response = test_client.post('/calculate', data = simulation_json, content_type = 'application/json') + response = test_client.post( + "/calculate", + data=simulation_json, + content_type="application/json", + ) assert response.status_code == client.OK -def test_periods(test_client): - simulation_json = json.dumps({ - "persons": { - "bill": {} +def test_periods(test_client) -> None: + simulation_json = json.dumps( + { + "persons": {"bill": {}}, + "households": { + "_": { + "parents": ["bill"], + "housing_tax": {"2017": None}, + "housing_occupancy_status": {"2017-01": None}, + }, }, - "households": { - "_": { - "parents": ["bill"], - "housing_tax": { - "2017": None - }, - "housing_occupancy_status": { - "2017-01": None - } - } - } - }) + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) + response_json = json.loads(response.data.decode("utf-8")) - yearly_variable = dpath.get(response_json, 'households/_/housing_tax') # web api year is an int - assert yearly_variable == {'2017': 200.0} + yearly_variable = dpath.util.get( + response_json, + "households/_/housing_tax", + ) # web api year is an int + assert yearly_variable == {"2017": 200.0} - monthly_variable = dpath.get(response_json, 'households/_/housing_occupancy_status') # web api month is a string - assert monthly_variable == {'2017-01': 'tenant'} + monthly_variable = dpath.util.get( + response_json, + "households/_/housing_occupancy_status", + ) # web api month is a string + assert monthly_variable == {"2017-01": "tenant"} -def test_handle_period_mismatch_error(test_client): +def test_two_periods(test_client) -> None: + """Test `calculate` on a request with mixed types periods: yearly periods following + monthly or daily periods to check dpath limitation on numeric keys (yearly periods). + Made to test the case where we have more than one path with a numeric in it. + See https://github.com/dpath-maintainers/dpath-python/issues/160 for more informations. + """ + simulation_json = json.dumps( + { + "persons": {"bill": {}}, + "households": { + "_": { + "parents": ["bill"], + "housing_tax": {"2017": None, "2018": None}, + "housing_occupancy_status": {"2017-01": None, "2018-01": None}, + }, + }, + }, + ) + response = post_json(test_client, simulation_json) + assert response.status_code == client.OK + + response_json = json.loads(response.data.decode("utf-8")) + + yearly_variable = dpath.util.get( + response_json, + "households/_/housing_tax", + ) # web api year is an int + assert yearly_variable == {"2017": 200.0, "2018": 200.0} + + monthly_variable = dpath.util.get( + response_json, + "households/_/housing_occupancy_status", + ) # web api month is a string + assert monthly_variable == {"2017-01": "tenant", "2018-01": "tenant"} + + +def test_handle_period_mismatch_error(test_client) -> None: variable = "housing_tax" period = "2017-01" - simulation_json = json.dumps({ - "persons": { - "bill": {} + simulation_json = json.dumps( + { + "persons": {"bill": {}}, + "households": { + "_": { + "parents": ["bill"], + variable: {period: 400}, + }, }, - "households": { - "_": { - "parents": ["bill"], - variable: { - period: 400 - }, - } - } - }) + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.BAD_REQUEST response_json = json.loads(response.data) - error = dpath.get(response_json, f'households/_/housing_tax/{period}') + error = dpath.util.get(response_json, f"households/_/housing_tax/{period}") message = f'Unable to set a value for variable "{variable}" for month-long period "{period}"' assert message in error -def test_gracefully_handle_unexpected_errors(test_client): - """ - Context +def test_gracefully_handle_unexpected_errors(test_client) -> None: + """Context. ======= Whenever an exception is raised by the calculation engine, the API will try @@ -368,19 +481,21 @@ def test_gracefully_handle_unexpected_errors(test_client): variable = "housing_tax" period = "1234-05-06" - simulation_json = json.dumps({ - "persons": { - "bill": {}, + simulation_json = json.dumps( + { + "persons": { + "bill": {}, }, - "households": { - "_": { - "parents": ["bill"], - variable: { - period: None, + "households": { + "_": { + "parents": ["bill"], + variable: { + period: None, }, - } - } - }) + }, + }, + }, + ) response = post_json(test_client, simulation_json) assert response.status_code == client.INTERNAL_SERVER_ERROR diff --git a/tests/web_api/test_entities.py b/tests/web_api/test_entities.py index 6f8153ed37..e7d0ef5b9b 100644 --- a/tests/web_api/test_entities.py +++ b/tests/web_api/test_entities.py @@ -1,37 +1,34 @@ -# -*- coding: utf-8 -*- - -from http import client import json +from http import client from openfisca_country_template import entities - # /entities -def test_return_code(test_client): - entities_response = test_client.get('/entities') +def test_return_code(test_client) -> None: + entities_response = test_client.get("/entities") assert entities_response.status_code == client.OK -def test_response_data(test_client): - entities_response = test_client.get('/entities') - entities_dict = json.loads(entities_response.data.decode('utf-8')) +def test_response_data(test_client) -> None: + entities_response = test_client.get("/entities") + entities_dict = json.loads(entities_response.data.decode("utf-8")) test_documentation = entities.Household.doc.strip() - assert entities_dict['household'] == { - 'description': 'All the people in a family or group who live together in the same place.', - 'documentation': test_documentation, - 'plural': 'households', - 'roles': { - 'child': { - 'description': 'Other individuals living in the household.', - 'plural': 'children', - }, - 'parent': { - 'description': 'The one or two adults in charge of the household.', - 'plural': 'parents', - 'max': 2, - } - } - } + assert entities_dict["household"] == { + "description": "All the people in a family or group who live together in the same place.", + "documentation": test_documentation, + "plural": "households", + "roles": { + "child": { + "description": "Other individuals living in the household.", + "plural": "children", + }, + "parent": { + "description": "The one or two adults in charge of the household.", + "plural": "parents", + "max": 2, + }, + }, + } diff --git a/tests/web_api/test_headers.py b/tests/web_api/test_headers.py index 54bbfd0df8..dc95437a09 100644 --- a/tests/web_api/test_headers.py +++ b/tests/web_api/test_headers.py @@ -1,13 +1,10 @@ -# -*- coding: utf-8 -*- +def test_package_name_header(test_client, distribution) -> None: + name = distribution.metadata.get("Name").lower() + parameters_response = test_client.get("/parameters") + assert parameters_response.headers.get("Country-Package") == name -from . import distribution - -def test_package_name_header(test_client): - parameters_response = test_client.get('/parameters') - assert parameters_response.headers.get('Country-Package') == distribution.key - - -def test_package_version_header(test_client): - parameters_response = test_client.get('/parameters') - assert parameters_response.headers.get('Country-Package-Version') == distribution.version +def test_package_version_header(test_client, distribution) -> None: + version = distribution.metadata.get("Version") + parameters_response = test_client.get("/parameters") + assert parameters_response.headers.get("Country-Package-Version") == version diff --git a/tests/web_api/test_helpers.py b/tests/web_api/test_helpers.py index cb049a0822..a1725cdfbf 100644 --- a/tests/web_api/test_helpers.py +++ b/tests/web_api/test_helpers.py @@ -1,53 +1,51 @@ import os -from openfisca_web_api.loader import parameters - from openfisca_core.parameters import load_parameter_file +from openfisca_web_api.loader import parameters - -dir_path = os.path.join(os.path.dirname(__file__), 'assets') +dir_path = os.path.join(os.path.dirname(__file__), "assets") -def test_build_api_values_history(): - file_path = os.path.join(dir_path, 'test_helpers.yaml') - parameter = load_parameter_file(name='dummy_name', file_path=file_path) +def test_build_api_values_history() -> None: + file_path = os.path.join(dir_path, "test_helpers.yaml") + parameter = load_parameter_file(name="dummy_name", file_path=file_path) values = { - '2017-01-01': 0.02, - '2015-01-01': 0.04, - '2013-01-01': 0.03, - } + "2017-01-01": 0.02, + "2015-01-01": 0.04, + "2013-01-01": 0.03, + } assert parameters.build_api_values_history(parameter) == values -def test_build_api_values_history_with_stop_date(): - file_path = os.path.join(dir_path, 'test_helpers_with_stop_date.yaml') - parameter = load_parameter_file(name='dummy_name', file_path=file_path) +def test_build_api_values_history_with_stop_date() -> None: + file_path = os.path.join(dir_path, "test_helpers_with_stop_date.yaml") + parameter = load_parameter_file(name="dummy_name", file_path=file_path) values = { - '2018-01-01': None, - '2017-01-01': 0.02, - '2015-01-01': 0.04, - '2013-01-01': 0.03, - } + "2018-01-01": None, + "2017-01-01": 0.02, + "2015-01-01": 0.04, + "2013-01-01": 0.03, + } assert parameters.build_api_values_history(parameter) == values -def test_get_value(): - values = {'2013-01-01': 0.03, '2017-01-01': 0.02, '2015-01-01': 0.04} +def test_get_value() -> None: + values = {"2013-01-01": 0.03, "2017-01-01": 0.02, "2015-01-01": 0.04} - assert parameters.get_value('2013-01-01', values) == 0.03 - assert parameters.get_value('2014-01-01', values) == 0.03 - assert parameters.get_value('2015-02-01', values) == 0.04 - assert parameters.get_value('2016-12-31', values) == 0.04 - assert parameters.get_value('2017-01-01', values) == 0.02 - assert parameters.get_value('2018-01-01', values) == 0.02 + assert parameters.get_value("2013-01-01", values) == 0.03 + assert parameters.get_value("2014-01-01", values) == 0.03 + assert parameters.get_value("2015-02-01", values) == 0.04 + assert parameters.get_value("2016-12-31", values) == 0.04 + assert parameters.get_value("2017-01-01", values) == 0.02 + assert parameters.get_value("2018-01-01", values) == 0.02 -def test_get_value_with_none(): - values = {'2015-01-01': 0.04, '2017-01-01': None} +def test_get_value_with_none() -> None: + values = {"2015-01-01": 0.04, "2017-01-01": None} - assert parameters.get_value('2016-12-31', values) == 0.04 - assert parameters.get_value('2017-01-01', values) is None - assert parameters.get_value('2011-01-01', values) is None + assert parameters.get_value("2016-12-31", values) == 0.04 + assert parameters.get_value("2017-01-01", values) is None + assert parameters.get_value("2011-01-01", values) is None diff --git a/tests/web_api/test_parameters.py b/tests/web_api/test_parameters.py index 8f65cca9af..77fee8f7ea 100644 --- a/tests/web_api/test_parameters.py +++ b/tests/web_api/test_parameters.py @@ -1,128 +1,164 @@ -from http import client import json -import pytest import re +from http import client +import pytest # /parameters -GITHUB_URL_REGEX = r'^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/parameters/(.)+\.yaml$' +GITHUB_URL_REGEX = r"^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/parameters/(.)+\.yaml$" -def test_return_code(test_client): - parameters_response = test_client.get('/parameters') +def test_return_code(test_client) -> None: + parameters_response = test_client.get("/parameters") assert parameters_response.status_code == client.OK -def test_response_data(test_client): - parameters_response = test_client.get('/parameters') - parameters = json.loads(parameters_response.data.decode('utf-8')) +def test_response_data(test_client) -> None: + parameters_response = test_client.get("/parameters") + parameters = json.loads(parameters_response.data.decode("utf-8")) - assert parameters['taxes.income_tax_rate'] == { - 'description': 'Income tax rate', - 'href': 'http://localhost/parameter/taxes/income_tax_rate' - } - assert parameters.get('taxes') is None + assert parameters["taxes.income_tax_rate"] == { + "description": "Income tax rate", + "href": "http://localhost/parameter/taxes/income_tax_rate", + } + assert parameters.get("taxes") is None # /parameter/ -def test_error_code_non_existing_parameter(test_client): - response = test_client.get('/parameter/non/existing.parameter') + +def test_error_code_non_existing_parameter(test_client) -> None: + response = test_client.get("/parameter/non/existing.parameter") assert response.status_code == client.NOT_FOUND -def test_return_code_existing_parameter(test_client): - response = test_client.get('/parameter/taxes/income_tax_rate') +def test_return_code_existing_parameter(test_client) -> None: + response = test_client.get("/parameter/taxes/income_tax_rate") assert response.status_code == client.OK -def test_legacy_parameter_route(test_client): - response = test_client.get('/parameter/taxes.income_tax_rate') +def test_legacy_parameter_route(test_client) -> None: + response = test_client.get("/parameter/taxes.income_tax_rate") assert response.status_code == client.OK -def test_parameter_values(test_client): - response = test_client.get('/parameter/taxes/income_tax_rate') +def test_parameter_values(test_client) -> None: + response = test_client.get("/parameter/taxes/income_tax_rate") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), ['description', 'id', 'metadata', 'source', 'values'] - assert parameter['id'] == 'taxes.income_tax_rate' - assert parameter['description'] == 'Income tax rate' - assert parameter['values'] == {'2015-01-01': 0.15, '2014-01-01': 0.14, '2013-01-01': 0.13, '2012-01-01': 0.16} - assert parameter['metadata'] == {'unit': '/1'} - assert re.match(GITHUB_URL_REGEX, parameter['source']) - assert 'taxes/income_tax_rate.yaml' in parameter['source'] + assert sorted(parameter.keys()), [ + "description", + "id", + "metadata", + "source", + "values", + ] + assert parameter["id"] == "taxes.income_tax_rate" + assert parameter["description"] == "Income tax rate" + assert parameter["values"] == { + "2015-01-01": 0.15, + "2014-01-01": 0.14, + "2013-01-01": 0.13, + "2012-01-01": 0.16, + } + assert parameter["metadata"] == {"unit": "/1"} + assert re.match(GITHUB_URL_REGEX, parameter["source"]) + assert "taxes/income_tax_rate.yaml" in parameter["source"] # 'documentation' attribute exists only when a value is defined - response = test_client.get('/parameter/benefits/housing_allowance') + response = test_client.get("/parameter/benefits/housing_allowance") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), ['description', 'documentation', 'id', 'metadata', 'source' == 'values'] + assert sorted(parameter.keys()), [ + "description", + "documentation", + "id", + "metadata", + "source" == "values", + ] assert ( - parameter['documentation'] == - 'A fraction of the rent.\nFrom the 1st of Dec 2016, the housing allowance no longer exists.' - ) + parameter["documentation"] + == "A fraction of the rent.\nFrom the 1st of Dec 2016, the housing allowance no longer exists." + ) -def test_parameter_node(tax_benefit_system, test_client): - response = test_client.get('/parameter/benefits') +def test_parameter_node(tax_benefit_system, test_client) -> None: + response = test_client.get("/parameter/benefits") assert response.status_code == client.OK parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), ['description', 'documentation', 'id', 'metadata', 'source' == 'subparams'] - assert parameter['documentation'] == ( + assert sorted(parameter.keys()), [ + "description", + "documentation", + "id", + "metadata", + "source" == "subparams", + ] + assert parameter["documentation"] == ( "Government support for the citizens and residents of society." "\nThey may be provided to people of any income level, as with social security," "\nbut usually it is intended to ensure that everyone can meet their basic human needs" "\nsuch as food and shelter.\n(See https://en.wikipedia.org/wiki/Welfare)" - ) + ) model_benefits = tax_benefit_system.parameters.benefits - assert parameter['subparams'].keys() == model_benefits.children.keys(), parameter['subparams'].keys() + assert parameter["subparams"].keys() == model_benefits.children.keys(), parameter[ + "subparams" + ].keys() - assert 'description' in parameter['subparams']['basic_income'] - assert parameter['subparams']['basic_income']['description'] == getattr( - model_benefits.basic_income, "description", None - ), parameter['subparams']['basic_income']['description'] + assert "description" in parameter["subparams"]["basic_income"] + assert parameter["subparams"]["basic_income"]["description"] == getattr( + model_benefits.basic_income, + "description", + None, + ), parameter["subparams"]["basic_income"]["description"] -def test_stopped_parameter_values(test_client): - response = test_client.get('/parameter/benefits/housing_allowance') +def test_stopped_parameter_values(test_client) -> None: + response = test_client.get("/parameter/benefits/housing_allowance") parameter = json.loads(response.data) - assert parameter['values'] == {'2016-12-01': None, '2010-01-01': 0.25} + assert parameter["values"] == {"2016-12-01": None, "2010-01-01": 0.25} -def test_scale(test_client): - response = test_client.get('/parameter/taxes/social_security_contribution') +def test_scale(test_client) -> None: + response = test_client.get("/parameter/taxes/social_security_contribution") parameter = json.loads(response.data) - assert sorted(list(parameter.keys())), ['brackets', 'description', 'id', 'metadata' == 'source'] - assert parameter['brackets'] == { - '2013-01-01': {"0.0": 0.03, "12000.0": 0.10}, - '2014-01-01': {"0.0": 0.03, "12100.0": 0.10}, - '2015-01-01': {"0.0": 0.04, "12200.0": 0.12}, - '2016-01-01': {"0.0": 0.04, "12300.0": 0.12}, - '2017-01-01': {"0.0": 0.02, "6000.0": 0.06, "12400.0": 0.12}, - } - - -def check_code(client, route, code): + assert sorted(parameter.keys()), [ + "brackets", + "description", + "id", + "metadata" == "source", + ] + assert parameter["brackets"] == { + "2013-01-01": {"0.0": 0.03, "12000.0": 0.10}, + "2014-01-01": {"0.0": 0.03, "12100.0": 0.10}, + "2015-01-01": {"0.0": 0.04, "12200.0": 0.12}, + "2016-01-01": {"0.0": 0.04, "12300.0": 0.12}, + "2017-01-01": {"0.0": 0.02, "6000.0": 0.06, "12400.0": 0.12}, + } + + +def check_code(client, route, code) -> None: response = client.get(route) assert response.status_code == code -@pytest.mark.parametrize("expected_code", [ - ('/parameters/', client.OK), - ('/parameter', client.NOT_FOUND), - ('/parameter/', client.NOT_FOUND), - ('/parameter/with-ÜNı©ød€', client.NOT_FOUND), - ('/parameter/with%20url%20encoding', client.NOT_FOUND), - ('/parameter/taxes/income_tax_rate/', client.OK), - ('/parameter/taxes/income_tax_rate/too-much-nesting', client.NOT_FOUND), - ('/parameter//taxes/income_tax_rate/', client.NOT_FOUND), - ]) -def test_routes_robustness(test_client, expected_code): +@pytest.mark.parametrize( + "expected_code", + [ + ("/parameters/", client.FOUND), + ("/parameter", client.NOT_FOUND), + ("/parameter/", client.FOUND), + ("/parameter/with-ÜNı©ød€", client.NOT_FOUND), + ("/parameter/with%20url%20encoding", client.NOT_FOUND), + ("/parameter/taxes/income_tax_rate/", client.FOUND), + ("/parameter/taxes/income_tax_rate/too-much-nesting", client.NOT_FOUND), + ("/parameter//taxes/income_tax_rate/", client.FOUND), + ], +) +def test_routes_robustness(test_client, expected_code) -> None: check_code(test_client, *expected_code) -def test_parameter_encoding(test_client): - parameter_response = test_client.get('/parameter/general/age_of_retirement') +def test_parameter_encoding(test_client) -> None: + parameter_response = test_client.get("/parameter/general/age_of_retirement") assert parameter_response.status_code == client.OK diff --git a/tests/web_api/test_spec.py b/tests/web_api/test_spec.py index 5e19752119..75a0f00e64 100644 --- a/tests/web_api/test_spec.py +++ b/tests/web_api/test_spec.py @@ -1,55 +1,77 @@ -import dpath import json -import pytest from http import client +import dpath.util +import pytest +from openapi_spec_validator import OpenAPIV30SpecValidator + -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert sorted(x) == sorted(y) -def test_return_code(test_client): - openAPI_response = test_client.get('/spec') +def test_return_code(test_client) -> None: + openAPI_response = test_client.get("/spec") assert openAPI_response.status_code == client.OK @pytest.fixture(scope="module") def body(test_client): - openAPI_response = test_client.get('/spec') - return json.loads(openAPI_response.data.decode('utf-8')) + openAPI_response = test_client.get("/spec") + return json.loads(openAPI_response.data.decode("utf-8")) -def test_paths(body): +def test_paths(body) -> None: assert_items_equal( - body['paths'], - ["/parameter/{parameterID}", - "/parameters", - "/variable/{variableID}", - "/variables", - "/entities", - "/trace", - "/calculate", - "/spec"] - ) + body["paths"], + [ + "/parameter/{parameterID}", + "/parameters", + "/variable/{variableID}", + "/variables", + "/entities", + "/trace", + "/calculate", + "/spec", + ], + ) -def test_entity_definition(body): - assert 'parents' in dpath.get(body, 'definitions/Household/properties') - assert 'children' in dpath.get(body, 'definitions/Household/properties') - assert 'salary' in dpath.get(body, 'definitions/Person/properties') - assert 'rent' in dpath.get(body, 'definitions/Household/properties') - assert 'number' == dpath.get(body, 'definitions/Person/properties/salary/additionalProperties/type') +def test_entity_definition(body) -> None: + assert "parents" in dpath.util.get(body, "components/schemas/Household/properties") + assert "children" in dpath.util.get(body, "components/schemas/Household/properties") + assert "salary" in dpath.util.get(body, "components/schemas/Person/properties") + assert "rent" in dpath.util.get(body, "components/schemas/Household/properties") + assert ( + dpath.util.get( + body, + "components/schemas/Person/properties/salary/additionalProperties/type", + ) + == "number" + ) -def test_situation_definition(body): - situation_input = body['definitions']['SituationInput'] - situation_output = body['definitions']['SituationOutput'] +def test_situation_definition(body) -> None: + situation_input = body["components"]["schemas"]["SituationInput"] + situation_output = body["components"]["schemas"]["SituationOutput"] for situation in situation_input, situation_output: - assert 'households' in dpath.get(situation, '/properties') - assert 'persons' in dpath.get(situation, '/properties') - assert "#/definitions/Household" == dpath.get(situation, '/properties/households/additionalProperties/$ref') - assert "#/definitions/Person" == dpath.get(situation, '/properties/persons/additionalProperties/$ref') + assert "households" in dpath.util.get(situation, "/properties") + assert "persons" in dpath.util.get(situation, "/properties") + assert ( + dpath.util.get( + situation, + "/properties/households/additionalProperties/$ref", + ) + == "#/components/schemas/Household" + ) + assert ( + dpath.util.get( + situation, + "/properties/persons/additionalProperties/$ref", + ) + == "#/components/schemas/Person" + ) -def test_host(body): - assert 'http' not in body['host'] +def test_respects_spec(body) -> None: + assert not list(OpenAPIV30SpecValidator(body).iter_errors()) diff --git a/tests/web_api/test_trace.py b/tests/web_api/test_trace.py index b59fbdb5f0..9463e69dfb 100644 --- a/tests/web_api/test_trace.py +++ b/tests/web_api/test_trace.py @@ -1,80 +1,129 @@ import copy -import dpath -from http import client import json +from http import client + +import dpath.util -from openfisca_country_template.situation_examples import single, couple +from openfisca_country_template.situation_examples import couple, single -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert set(x) == set(y) -def test_trace_basic(test_client): +def test_trace_basic(test_client) -> None: simulation_json = json.dumps(single) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) assert response.status_code == client.OK - response_json = json.loads(response.data.decode('utf-8')) - disposable_income_value = dpath.util.get(response_json, 'trace/disposable_income<2017-01>/value') + response_json = json.loads(response.data.decode("utf-8")) + disposable_income_value = dpath.util.get( + response_json, + "trace/disposable_income<2017-01>/value", + ) assert isinstance(disposable_income_value, list) assert isinstance(disposable_income_value[0], float) - disposable_income_dep = dpath.util.get(response_json, 'trace/disposable_income<2017-01>/dependencies') + disposable_income_dep = dpath.util.get( + response_json, + "trace/disposable_income<2017-01>/dependencies", + ) assert_items_equal( disposable_income_dep, - ['salary<2017-01>', 'basic_income<2017-01>', 'income_tax<2017-01>', 'social_security_contribution<2017-01>'] - ) - basic_income_dep = dpath.util.get(response_json, 'trace/basic_income<2017-01>/dependencies') - assert_items_equal(basic_income_dep, ['age<2017-01>']) - - -def test_trace_enums(test_client): + [ + "salary<2017-01>", + "basic_income<2017-01>", + "income_tax<2017-01>", + "social_security_contribution<2017-01>", + ], + ) + basic_income_dep = dpath.util.get( + response_json, + "trace/basic_income<2017-01>/dependencies", + ) + assert_items_equal(basic_income_dep, ["age<2017-01>"]) + + +def test_trace_enums(test_client) -> None: new_single = copy.deepcopy(single) - new_single['households']['_']['housing_occupancy_status'] = {"2017-01": None} + new_single["households"]["_"]["housing_occupancy_status"] = {"2017-01": None} simulation_json = json.dumps(new_single) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) response_json = json.loads(response.data) - housing_status = dpath.util.get(response_json, 'trace/housing_occupancy_status<2017-01>/value') - assert housing_status[0] == 'tenant' # The default value + housing_status = dpath.util.get( + response_json, + "trace/housing_occupancy_status<2017-01>/value", + ) + assert housing_status[0] == "tenant" # The default value -def test_entities_description(test_client): +def test_entities_description(test_client) -> None: simulation_json = json.dumps(couple) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') - response_json = json.loads(response.data.decode('utf-8')) + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) + response_json = json.loads(response.data.decode("utf-8")) assert_items_equal( - dpath.util.get(response_json, 'entitiesDescription/persons'), - ['Javier', "Alicia"] - ) + dpath.util.get(response_json, "entitiesDescription/persons"), + ["Javier", "Alicia"], + ) -def test_root_nodes(test_client): +def test_root_nodes(test_client) -> None: simulation_json = json.dumps(couple) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') - response_json = json.loads(response.data.decode('utf-8')) + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) + response_json = json.loads(response.data.decode("utf-8")) assert_items_equal( - dpath.util.get(response_json, 'requestedCalculations'), - ['disposable_income<2017-01>', 'total_benefits<2017-01>', 'total_taxes<2017-01>'] - ) + dpath.util.get(response_json, "requestedCalculations"), + [ + "disposable_income<2017-01>", + "total_benefits<2017-01>", + "total_taxes<2017-01>", + ], + ) -def test_str_variable(test_client): +def test_str_variable(test_client) -> None: new_couple = copy.deepcopy(couple) - new_couple['households']['_']['postal_code'] = {'2017-01': None} + new_couple["households"]["_"]["postal_code"] = {"2017-01": None} simulation_json = json.dumps(new_couple) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) assert response.status_code == client.OK -def test_trace_parameters(test_client): +def test_trace_parameters(test_client) -> None: new_couple = copy.deepcopy(couple) - new_couple['households']['_']['housing_tax'] = {'2017': None} + new_couple["households"]["_"]["housing_tax"] = {"2017": None} simulation_json = json.dumps(new_couple) - response = test_client.post('/trace', data = simulation_json, content_type = 'application/json') - response_json = json.loads(response.data.decode('utf-8')) - - assert len(dpath.util.get(response_json, 'trace/housing_tax<2017>/parameters')) > 0 - taxes__housing_tax__minimal_amount = dpath.util.get(response_json, 'trace/housing_tax<2017>/parameters/taxes.housing_tax.minimal_amount<2017-01-01>') + response = test_client.post( + "/trace", + data=simulation_json, + content_type="application/json", + ) + response_json = json.loads(response.data.decode("utf-8")) + + assert len(dpath.util.get(response_json, "trace/housing_tax<2017>/parameters")) > 0 + taxes__housing_tax__minimal_amount = dpath.util.get( + response_json, + "trace/housing_tax<2017>/parameters/taxes.housing_tax.minimal_amount<2017-01-01>", + ) assert taxes__housing_tax__minimal_amount == 200 diff --git a/tests/web_api/test_variables.py b/tests/web_api/test_variables.py index 4581608aa8..d3b46dfff9 100644 --- a/tests/web_api/test_variables.py +++ b/tests/web_api/test_variables.py @@ -1,14 +1,15 @@ -from http import client import json -import pytest import re +from http import client + +import pytest -def assert_items_equal(x, y): +def assert_items_equal(x, y) -> None: assert set(x) == set(y) -GITHUB_URL_REGEX = r'^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/variables/(.)+\.py#L\d+-L\d+$' +GITHUB_URL_REGEX = r"^https://github\.com/openfisca/country-template/blob/\d+\.\d+\.\d+((.dev|rc)\d+)?/openfisca_country_template/variables/(.)+\.py#L\d+-L\d+$" # /variables @@ -16,147 +17,154 @@ def assert_items_equal(x, y): @pytest.fixture(scope="module") def variables_response(test_client): - variables_response = test_client.get("/variables") - return variables_response + return test_client.get("/variables") -def test_return_code(variables_response): +def test_return_code(variables_response) -> None: assert variables_response.status_code == client.OK -def test_response_data(variables_response): - variables = json.loads(variables_response.data.decode('utf-8')) - assert variables['birth'] == { - 'description': 'Birth date', - 'href': 'http://localhost/variable/birth' - } +def test_response_data(variables_response) -> None: + variables = json.loads(variables_response.data.decode("utf-8")) + assert variables["birth"] == { + "description": "Birth date", + "href": "http://localhost/variable/birth", + } # /variable/ -def test_error_code_non_existing_variable(test_client): - response = test_client.get('/variable/non_existing_variable') +def test_error_code_non_existing_variable(test_client) -> None: + response = test_client.get("/variable/non_existing_variable") assert response.status_code == client.NOT_FOUND @pytest.fixture(scope="module") def input_variable_response(test_client): - input_variable_response = test_client.get('/variable/birth') - return input_variable_response + return test_client.get("/variable/birth") -def test_return_code_existing_input_variable(input_variable_response): +def test_return_code_existing_input_variable(input_variable_response) -> None: assert input_variable_response.status_code == client.OK -def check_input_variable_value(key, expected_value, input_variable=None): +def check_input_variable_value(key, expected_value, input_variable=None) -> None: assert input_variable[key] == expected_value -@pytest.mark.parametrize("expected_values", [ - ('description', 'Birth date'), - ('valueType', 'Date'), - ('defaultValue', '1970-01-01'), - ('definitionPeriod', 'ETERNITY'), - ('entity', 'person'), - ('references', ['https://en.wiktionary.org/wiki/birthdate']), - ]) -def test_input_variable_value(expected_values, input_variable_response): - input_variable = json.loads(input_variable_response.data.decode('utf-8')) +@pytest.mark.parametrize( + "expected_values", + [ + ("description", "Birth date"), + ("valueType", "Date"), + ("defaultValue", "1970-01-01"), + ("definitionPeriod", "ETERNITY"), + ("entity", "person"), + ("references", ["https://en.wiktionary.org/wiki/birthdate"]), + ], +) +def test_input_variable_value(expected_values, input_variable_response) -> None: + input_variable = json.loads(input_variable_response.data.decode("utf-8")) check_input_variable_value(*expected_values, input_variable=input_variable) -def test_input_variable_github_url(test_client): - input_variable_response = test_client.get('/variable/income_tax') - input_variable = json.loads(input_variable_response.data.decode('utf-8')) +def test_input_variable_github_url(test_client) -> None: + input_variable_response = test_client.get("/variable/income_tax") + input_variable = json.loads(input_variable_response.data.decode("utf-8")) - assert re.match(GITHUB_URL_REGEX, input_variable['source']) + assert re.match(GITHUB_URL_REGEX, input_variable["source"]) -def test_return_code_existing_variable(test_client): - variable_response = test_client.get('/variable/income_tax') +def test_return_code_existing_variable(test_client) -> None: + variable_response = test_client.get("/variable/income_tax") assert variable_response.status_code == client.OK -def check_variable_value(key, expected_value, variable=None): +def check_variable_value(key, expected_value, variable=None) -> None: assert variable[key] == expected_value -@pytest.mark.parametrize("expected_values", [ - ('description', 'Income tax'), - ('valueType', 'Float'), - ('defaultValue', 0), - ('definitionPeriod', 'MONTH'), - ('entity', 'person'), - ]) -def test_variable_value(expected_values, test_client): - variable_response = test_client.get('/variable/income_tax') - variable = json.loads(variable_response.data.decode('utf-8')) +@pytest.mark.parametrize( + "expected_values", + [ + ("description", "Income tax"), + ("valueType", "Float"), + ("defaultValue", 0), + ("definitionPeriod", "MONTH"), + ("entity", "person"), + ], +) +def test_variable_value(expected_values, test_client) -> None: + variable_response = test_client.get("/variable/income_tax") + variable = json.loads(variable_response.data.decode("utf-8")) check_variable_value(*expected_values, variable=variable) -def test_variable_formula_github_link(test_client): - variable_response = test_client.get('/variable/income_tax') - variable = json.loads(variable_response.data.decode('utf-8')) - assert re.match(GITHUB_URL_REGEX, variable['formulas']['0001-01-01']['source']) +def test_variable_formula_github_link(test_client) -> None: + variable_response = test_client.get("/variable/income_tax") + variable = json.loads(variable_response.data.decode("utf-8")) + assert re.match(GITHUB_URL_REGEX, variable["formulas"]["0001-01-01"]["source"]) -def test_variable_formula_content(test_client): - variable_response = test_client.get('/variable/income_tax') - variable = json.loads(variable_response.data.decode('utf-8')) - content = variable['formulas']['0001-01-01']['content'] +def test_variable_formula_content(test_client) -> None: + variable_response = test_client.get("/variable/income_tax") + variable = json.loads(variable_response.data.decode("utf-8")) + content = variable["formulas"]["0001-01-01"]["content"] assert "def formula(person, period, parameters):" in content - assert "return person(\"salary\", period) * parameters(period).taxes.income_tax_rate" in content + assert ( + 'return person("salary", period) * parameters(period).taxes.income_tax_rate' + in content + ) -def test_null_values_are_dropped(test_client): - variable_response = test_client.get('/variable/age') - variable = json.loads(variable_response.data.decode('utf-8')) - assert 'references' not in variable.keys() +def test_null_values_are_dropped(test_client) -> None: + variable_response = test_client.get("/variable/age") + variable = json.loads(variable_response.data.decode("utf-8")) + assert "references" not in variable -def test_variable_with_start_and_stop_date(test_client): - response = test_client.get('/variable/housing_allowance') - variable = json.loads(response.data.decode('utf-8')) - assert_items_equal(variable['formulas'], ['1980-01-01', '2016-12-01']) - assert variable['formulas']['2016-12-01'] is None - assert 'formula' in variable['formulas']['1980-01-01']['content'] +def test_variable_with_start_and_stop_date(test_client) -> None: + response = test_client.get("/variable/housing_allowance") + variable = json.loads(response.data.decode("utf-8")) + assert_items_equal(variable["formulas"], ["1980-01-01", "2016-12-01"]) + assert variable["formulas"]["2016-12-01"] is None + assert "formula" in variable["formulas"]["1980-01-01"]["content"] -def test_variable_with_enum(test_client): - response = test_client.get('/variable/housing_occupancy_status') - variable = json.loads(response.data.decode('utf-8')) - assert variable['valueType'] == 'String' - assert variable['defaultValue'] == 'tenant' - assert 'possibleValues' in variable.keys() - assert variable['possibleValues'] == { - 'free_lodger': 'Free lodger', - 'homeless': 'Homeless', - 'owner': 'Owner', - 'tenant': 'Tenant'} +def test_variable_with_enum(test_client) -> None: + response = test_client.get("/variable/housing_occupancy_status") + variable = json.loads(response.data.decode("utf-8")) + assert variable["valueType"] == "String" + assert variable["defaultValue"] == "tenant" + assert "possibleValues" in variable + assert variable["possibleValues"] == { + "free_lodger": "Free lodger", + "homeless": "Homeless", + "owner": "Owner", + "tenant": "Tenant", + } @pytest.fixture(scope="module") def dated_variable_response(test_client): - dated_variable_response = test_client.get('/variable/basic_income') - return dated_variable_response + return test_client.get("/variable/basic_income") -def test_return_code_existing_dated_variable(dated_variable_response): +def test_return_code_existing_dated_variable(dated_variable_response) -> None: assert dated_variable_response.status_code == client.OK -def test_dated_variable_formulas_dates(dated_variable_response): - dated_variable = json.loads(dated_variable_response.data.decode('utf-8')) - assert_items_equal(dated_variable['formulas'], ['2016-12-01', '2015-12-01']) +def test_dated_variable_formulas_dates(dated_variable_response) -> None: + dated_variable = json.loads(dated_variable_response.data.decode("utf-8")) + assert_items_equal(dated_variable["formulas"], ["2016-12-01", "2015-12-01"]) -def test_dated_variable_formulas_content(dated_variable_response): - dated_variable = json.loads(dated_variable_response.data.decode('utf-8')) - formula_code_2016 = dated_variable['formulas']['2016-12-01']['content'] - formula_code_2015 = dated_variable['formulas']['2015-12-01']['content'] +def test_dated_variable_formulas_content(dated_variable_response) -> None: + dated_variable = json.loads(dated_variable_response.data.decode("utf-8")) + formula_code_2016 = dated_variable["formulas"]["2016-12-01"]["content"] + formula_code_2015 = dated_variable["formulas"]["2015-12-01"]["content"] assert "def formula_2016_12(person, period, parameters):" in formula_code_2016 assert "return" in formula_code_2016 @@ -164,16 +172,22 @@ def test_dated_variable_formulas_content(dated_variable_response): assert "return" in formula_code_2015 -def test_variable_encoding(test_client): - variable_response = test_client.get('/variable/pension') +def test_variable_encoding(test_client) -> None: + variable_response = test_client.get("/variable/pension") assert variable_response.status_code == client.OK -def test_variable_documentation(test_client): - response = test_client.get('/variable/housing_allowance') - variable = json.loads(response.data.decode('utf-8')) - assert variable['documentation'] == "This allowance was introduced on the 1st of Jan 1980.\nIt disappeared in Dec 2016." +def test_variable_documentation(test_client) -> None: + response = test_client.get("/variable/housing_allowance") + variable = json.loads(response.data.decode("utf-8")) + assert ( + variable["documentation"] + == "This allowance was introduced on the 1st of Jan 1980.\nIt disappeared in Dec 2016." + ) - formula_documentation = variable['formulas']['1980-01-01']['documentation'] + formula_documentation = variable["formulas"]["1980-01-01"]["documentation"] assert "Housing allowance." in formula_documentation - assert "Calculating it before this date will always return the variable default value, 0." in formula_documentation + assert ( + "Calculating it before this date will always return the variable default value, 0." + in formula_documentation + )