diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..0a9c156 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,37 @@ +--- +name: Bug report +about: Create a bug report to help us improve +labels: "bug" +--- + +**Describe the bug** + + + +**To Reproduce** + +Steps to reproduce the behavior: + +1. ... +2. ... +3. ... + +**Expected behavior** + + + +**System (please complete the following information):** + +- OS: [e.g. Ubuntu 18.04] +- Python version: [e.g. Python 3.8] +- PySpark version: [e.g. PySpark 3.5.1] + +**Additional context** + + + +**Are you planning on creating a PR?** + + + +- [ ] I'm planning to make a pull request diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..ddac768 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,23 @@ +--- +name: Feature request +about: Suggest a new feature +labels: "enhancement" +--- + +**Is your feature request related to a problem? Please describe.** + + + +**Describe the solution you would like** + + + +**Additional context** + + + +**Are you planning on creating a PR?** + + + +- [ ] I'm planning to make a pull request diff --git a/.github/actions/setup-poetry-env/action.yml b/.github/actions/setup-poetry-env/action.yml index 474e3c2..7969e6d 100644 --- a/.github/actions/setup-poetry-env/action.yml +++ b/.github/actions/setup-poetry-env/action.yml @@ -14,7 +14,7 @@ inputs: required: false description: "Install the docs dependency group" default: 'false' - + runs: using: "composite" steps: @@ -60,4 +60,4 @@ runs: run: | poetry run python --version poetry run pyspark --version - shell: bash \ No newline at end of file + shell: bash diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..4a1217d --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,10 @@ +**PR Checklist** + +- [ ] A description of the changes is added to the description of this PR. +- [ ] If there is a related issue, make sure it is linked to this PR. +- [ ] If you've fixed a bug or added code that should be tested, add tests!
+- [ ] If you've added or modified a feature, documentation in `docs` is updated. + +**Description of changes** + + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3acfd05..8739de5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,6 +19,9 @@ jobs: - name: Check lock file run: poetry lock --check + - name: Run pre-commit hooks + run: poetry run pre-commit run -a + test: runs-on: ubuntu-latest strategy: diff --git a/.github/workflows/mkdocs.yml b/.github/workflows/mkdocs.yml index 4e3a1f3..af11687 100644 --- a/.github/workflows/mkdocs.yml +++ b/.github/workflows/mkdocs.yml @@ -12,12 +12,12 @@ jobs: steps: - uses: actions/checkout@v2 - + - name: Set up the environment uses: ./.github/actions/setup-poetry-env with: with-docs: true - + - name: Setup GH run: | sudo apt update && sudo apt install -y git diff --git a/.gitignore b/.gitignore index 7ce55ef..4dc8f97 100644 --- a/.gitignore +++ b/.gitignore @@ -1,23 +1,106 @@ -build/ -dist/ -chispa.egg-info/ -.cache/ -tmp/ -.idea/ -.DS_Store +.DS_Store .python_version +# Emacs +.dir-locals.el + +# VSCode +.vscode + +# Below are sections from https://github.com/github/gitignore/blob/main/Python.gitignore # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ .pytest_cache/ +cover/ -# MKDocs -site +# Translations +*.mo +*.pot -# VSCode -.vscode +# Sphinx documentation +docs/_build/ -# Emacs -.dir-locals.el \ No newline at end of file +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a868e1b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.6.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.5.2" + hooks: + - id: ruff + args: [--exit-non-zero-on-fix] + - id: ruff-format diff --git a/Makefile b/Makefile index 80f6af9..c1f8aed 100644 --- a/Makefile +++ b/Makefile @@ -3,10 +3,25 @@ install: ## Install the Poetry environment @echo "Creating virtual environment using Poetry" @poetry install +.PHONY: check +check: ## Run code quality checks + @echo "Running pre-commit hooks" + @poetry run pre-commit run -a + .PHONY: test test: ## Run unit tests @echo "Running unit tests" - @poetry run pytest tests + @poetry run pytest tests --cov=chispa --cov-report=term + +.PHONY: test-cov-html +test-cov-html: ## Run unit tests and create a coverage report + @echo "Running unit tests and generating HTML report" + @poetry run pytest tests --cov=chispa --cov-report=html + +.PHONY: test-cov-xml +test-cov-xml: ## Run unit tests and create a coverage report in xml format + @echo "Running unit tests and generating XML report" + @poetry run pytest tests --cov=chispa --cov-report=xml .PHONY: build build: clean-build ## Build wheel and sdist files using Poetry diff --git a/NOTICE-binary.md b/NOTICE-binary.md new file mode 100644 index 0000000..11d24b7 --- /dev/null +++ b/NOTICE-binary.md @@ -0,0 +1,142 @@ +# apache spark + Apache Spark + Copyright 2014 and onwards The Apache Software Foundation. + +# findspark + Copyright (c) 2015, Min RK All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, + are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this list + of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + + Neither the name of findspark nor the names of its contributors may be used to endorse + or promote products derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, + INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# pre-commit + Copyright (c) 2014 pre-commit dev team: Anthony Sottile, Ken Struys + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. + +# prettytable + Copyright (c) 2009-2014 Luke Maurits + All rights reserved. + With contributions from: + * Chris Clark + * Klein Stephane + * John Filleau + * Vladimir Vrzić + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. 
+ +# pytest + The MIT License (MIT) + + Copyright (c) 2004 Holger Krekel and others + + Permission is hereby granted, free of charge, to any person obtaining a copy of + this software and associated documentation files (the "Software"), to deal in + the Software without restriction, including without limitation the rights to + use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + of the Software, and to permit persons to whom the Software is furnished to do + so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + +# pytest-describe + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, + INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE + FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +# pytest-cov + The MIT License + + Copyright (c) 2010 Meme Dough + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. 
diff --git a/README.md b/README.md index 5660844..0bd1a92 100644 --- a/README.md +++ b/README.md @@ -311,31 +311,51 @@ assert_df_equality(df1, df2, allow_nan_equality=True) ## Customize formatting -*Available in chispa 0.10+*. - You can specify custom formats for the printed error messages as follows: ```python -@dataclass -class MyFormats: - mismatched_rows = ["light_yellow"] - matched_rows = ["cyan", "bold"] - mismatched_cells = ["purple"] - matched_cells = ["blue"] +from chispa import FormattingConfig + +formats = FormattingConfig( + mismatched_rows={"color": "light_yellow"}, + matched_rows={"color": "cyan", "style": "bold"}, + mismatched_cells={"color": "purple"}, + matched_cells={"color": "blue"}, + ) + +assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) +``` + +or equivalently, using the `Color` and `Style` enums: + +```python +from chispa import FormattingConfig, Color, Style + +formats = FormattingConfig( + mismatched_rows={"color": Color.LIGHT_YELLOW}, + matched_rows={"color": Color.CYAN, "style": Style.BOLD}, + mismatched_cells={"color": Color.PURPLE}, + matched_cells={"color": Color.BLUE}, + ) -assert_basic_rows_equality(df1.collect(), df2.collect(), formats=MyFormats()) +assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) ``` You can also define these formats in `conftest.py` and inject them via a fixture: ```python @pytest.fixture() -def my_formats(): - return MyFormats() +def chispa_formats(): + return FormattingConfig( + mismatched_rows={"color": "light_yellow"}, + matched_rows={"color": "cyan", "style": "bold"}, + mismatched_cells={"color": "purple"}, + matched_cells={"color": "blue"}, + ) -def test_shows_assert_basic_rows_equality(my_formats): +def test_shows_assert_basic_rows_equality(chispa_formats): +def test_shows_assert_basic_rows_equality(chispa_formats): ... - assert_basic_rows_equality(df1.collect(), df2.collect(), formats=my_formats) + assert_basic_rows_equality(df1.collect(), df2.collect(), formats=chispa_formats) ``` ![custom_formats](https://github.com/MrPowers/chispa/blob/main/images/custom_formats.png) @@ -473,12 +493,6 @@ PySpark 2 support will be dropped when chispa 1.x is released. TODO: Need to benchmark these methods vs. the spark-testing-base ones -## Vendored dependencies - -* [PrettyTable](https://github.com/jazzband/prettytable) - -The dependencies are vendored to save you from dependency hell. - ## Developing chispa on your local machine You are encouraged to clone and/or fork this repo.
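Reviewer note: below is a minimal end-to-end sketch of the `FormattingConfig` API this PR introduces, for anyone evaluating the README changes above. It is not part of the diff; the `local[1]` SparkSession and the sample DataFrames are illustrative assumptions.

```python
from pyspark.sql import SparkSession

from chispa import Color, FormattingConfig, Style, assert_df_equality

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Two DataFrames that differ in a single cell, to trigger the colored diff output.
df1 = spark.createDataFrame([("jose", 1), ("li", 2)], ["name", "age"])
df2 = spark.createDataFrame([("jose", 1), ("li", 3)], ["name", "age"])

# Dict values are parsed via Format.from_dict, which accepts Color/Style enum
# members as well as plain strings such as "cyan" or "bold".
formats = FormattingConfig(
    mismatched_rows={"color": Color.LIGHT_YELLOW},
    matched_rows={"color": Color.CYAN, "style": Style.BOLD},
    mismatched_cells={"color": Color.PURPLE},
    matched_cells={"color": Color.BLUE},
)

# Raises DataFramesNotEqualError, rendering matched/mismatched rows per the config.
assert_df_equality(df1, df2, formats=formats)
```

For backwards compatibility, a legacy `DefaultFormats` instance (or any dataclass with the same four list-of-strings fields) is still accepted and converted via `FormattingConfig._from_arbitrary_dataclass`, with a `DeprecationWarning`.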
diff --git a/chispa/__init__.py b/chispa/__init__.py index 0c59b1b..cef3d31 100644 --- a/chispa/__init__.py +++ b/chispa/__init__.py @@ -1,23 +1,27 @@ -import sys +from __future__ import annotations + import os +import sys from glob import glob # Add PySpark to the library path based on the value of SPARK_HOME if pyspark is not already in our path try: - from pyspark import context + from pyspark import context # noqa: F401 except ImportError: # We need to add PySpark, try use findspark, or failback to the "manually" find it try: import findspark + findspark.init() except ImportError: try: - spark_home = os.environ['SPARK_HOME'] - sys.path.append(os.path.join(spark_home, 'python')) - py4j_src_zip = glob(os.path.join(spark_home, 'python', 'lib', 'py4j-*-src.zip')) + spark_home = os.environ["SPARK_HOME"] + sys.path.append(os.path.join(spark_home, "python")) + py4j_src_zip = glob(os.path.join(spark_home, "python", "lib", "py4j-*-src.zip")) if len(py4j_src_zip) == 0: - raise ValueError('py4j source archive not found in %s' - % os.path.join(spark_home, 'python', 'lib')) + raise ValueError( + "py4j source archive not found in {}".format(os.path.join(spark_home, "python", "lib")) + ) else: py4j_src_zip = sorted(py4j_src_zip)[::-1] sys.path.append(py4j_src_zip[0]) @@ -25,26 +29,71 @@ print("Can't find Apache Spark. Please set environment variable SPARK_HOME to root of installation!") exit(-1) -from .dataframe_comparer import DataFramesNotEqualError, assert_df_equality, assert_approx_df_equality -from .column_comparer import ColumnsNotEqualError, assert_column_equality, assert_approx_column_equality -from .rows_comparer import assert_basic_rows_equality from chispa.default_formats import DefaultFormats +from chispa.formatting import Color, Format, FormattingConfig, Style + +from .column_comparer import ( + ColumnsNotEqualError, + assert_approx_column_equality, + assert_column_equality, +) +from .dataframe_comparer import ( + DataFramesNotEqualError, + assert_approx_df_equality, + assert_df_equality, +) +from .rows_comparer import assert_basic_rows_equality + + +class Chispa: + def __init__(self, formats: FormattingConfig | None = None, default_output=None): + if not formats: + self.formats = FormattingConfig() + elif isinstance(formats, FormattingConfig): + self.formats = formats + else: + self.formats = FormattingConfig._from_arbitrary_dataclass(formats) -class Chispa(): - def __init__(self, formats=DefaultFormats(), default_output=None): - self.formats = formats self.default_outputs = default_output - def assert_df_equality(self, df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, - ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False): + def assert_df_equality( + self, + df1, + df2, + ignore_nullable=False, + transforms=None, + allow_nan_equality=False, + ignore_column_order=False, + ignore_row_order=False, + underline_cells=False, + ignore_metadata=False, + ): return assert_df_equality( - df1, - df2, - ignore_nullable, - transforms, + df1, + df2, + ignore_nullable, + transforms, allow_nan_equality, - ignore_column_order, - ignore_row_order, - underline_cells, + ignore_column_order, + ignore_row_order, + underline_cells, ignore_metadata, - self.formats) \ No newline at end of file + self.formats, + ) + + +__all__ = ( + "DataFramesNotEqualError", + "assert_df_equality", + "assert_approx_df_equality", + "ColumnsNotEqualError", + "assert_column_equality", + "assert_approx_column_equality", + "assert_basic_rows_equality", + "Style", 
+ "Color", + "FormattingConfig", + "Format", + "Chispa", + "DefaultFormats", +) diff --git a/chispa/bcolors.py b/chispa/bcolors.py index bbbb930..7b77215 100644 --- a/chispa/bcolors.py +++ b/chispa/bcolors.py @@ -1,32 +1,35 @@ +from __future__ import annotations + + class bcolors: - NC = '\033[0m' # No Color, reset all - - Bold = '\033[1m' - Underlined = '\033[4m' - Blink = '\033[5m' - Inverted = '\033[7m' - Hidden = '\033[8m' - - Black = '\033[30m' - Red = '\033[31m' - Green = '\033[32m' - Yellow = '\033[33m' - Blue = '\033[34m' - Purple = '\033[35m' - Cyan = '\033[36m' - LightGray = '\033[37m' - DarkGray = '\033[30m' - LightRed = '\033[31m' - LightGreen = '\033[32m' - LightYellow = '\033[93m' - LightBlue = '\033[34m' - LightPurple = '\033[35m' - LightCyan = '\033[36m' - White = '\033[97m' + NC = "\033[0m" # No Color, reset all + + Bold = "\033[1m" + Underlined = "\033[4m" + Blink = "\033[5m" + Inverted = "\033[7m" + Hidden = "\033[8m" + + Black = "\033[30m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Purple = "\033[35m" + Cyan = "\033[36m" + LightGray = "\033[37m" + DarkGray = "\033[30m" + LightRed = "\033[31m" + LightGreen = "\033[32m" + LightYellow = "\033[93m" + LightBlue = "\033[34m" + LightPurple = "\033[35m" + LightCyan = "\033[36m" + White = "\033[97m" # Style - Bold = '\033[1m' - Underline = '\033[4m' + Bold = "\033[1m" + Underline = "\033[4m" def blue(s: str) -> str: diff --git a/chispa/column_comparer.py b/chispa/column_comparer.py index b55f3ef..e6a380c 100644 --- a/chispa/column_comparer.py +++ b/chispa/column_comparer.py @@ -1,10 +1,14 @@ -from chispa.bcolors import * -from chispa.prettytable import PrettyTable +from __future__ import annotations + +from prettytable import PrettyTable + +from chispa.bcolors import bcolors class ColumnsNotEqualError(Exception): - """The columns are not equal""" - pass + """The columns are not equal""" + + pass def assert_column_equality(df, col_name1, col_name2): @@ -35,11 +39,11 @@ def assert_approx_column_equality(df, col_name1, col_name2, precision): first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed # when one is None and the other isn't, they're not equal - if (elements[0] == None and elements[1] != None) or (elements[0] != None and elements[1] == None): + if (elements[0] is None and elements[1] is not None) or (elements[0] is not None and elements[1] is None): all_rows_equal = False t.add_row([str(elements[0]), str(elements[1])]) # when both are None, they're equal - elif elements[0] == None and elements[1] == None: + elif elements[0] is None and elements[1] is None: t.add_row([first, second]) # when the diff is less than the threshhold, they're approximately equal elif abs(elements[0] - elements[1]) < precision: @@ -48,6 +52,5 @@ def assert_approx_column_equality(df, col_name1, col_name2, precision): else: all_rows_equal = False t.add_row([str(elements[0]), str(elements[1])]) - if all_rows_equal == False: + if all_rows_equal is False: raise ColumnsNotEqualError("\n" + t.get_string()) - diff --git a/chispa/dataframe_comparer.py b/chispa/dataframe_comparer.py index 6b2fc67..167ceb2 100644 --- a/chispa/dataframe_comparer.py +++ b/chispa/dataframe_comparer.py @@ -1,17 +1,39 @@ -from chispa.schema_comparer import assert_schema_equality -from chispa.default_formats import DefaultFormats -from chispa.rows_comparer import assert_basic_rows_equality, assert_generic_rows_equality -from chispa.row_comparer import 
are_rows_equal_enhanced, are_rows_approx_equal +from __future__ import annotations + from functools import reduce +from chispa.formatting import FormattingConfig +from chispa.row_comparer import are_rows_approx_equal, are_rows_equal_enhanced +from chispa.rows_comparer import ( + assert_basic_rows_equality, + assert_generic_rows_equality, +) +from chispa.schema_comparer import assert_schema_equality + class DataFramesNotEqualError(Exception): - """The DataFrames are not equal""" - pass + """The DataFrames are not equal""" + pass + + +def assert_df_equality( + df1, + df2, + ignore_nullable=False, + transforms=None, + allow_nan_equality=False, + ignore_column_order=False, + ignore_row_order=False, + underline_cells=False, + ignore_metadata=False, + formats: FormattingConfig | None = None, +): + if not formats: + formats = FormattingConfig() + elif not isinstance(formats, FormattingConfig): + formats = FormattingConfig._from_arbitrary_dataclass(formats) -def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, - ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False, formats=DefaultFormats()): if transforms is None: transforms = [] if ignore_column_order: @@ -23,10 +45,20 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n assert_schema_equality(df1.schema, df2.schema, ignore_nullable, ignore_metadata) if allow_nan_equality: assert_generic_rows_equality( - df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], underline_cells=underline_cells, formats=formats) + df1.collect(), + df2.collect(), + are_rows_equal_enhanced, + [True], + underline_cells=underline_cells, + formats=formats, + ) else: assert_basic_rows_equality( - df1.collect(), df2.collect(), underline_cells=underline_cells, formats=formats) + df1.collect(), + df2.collect(), + underline_cells=underline_cells, + formats=formats, + ) def are_dfs_equal(df1, df2): @@ -37,8 +69,22 @@ def are_dfs_equal(df1, df2): return True -def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transforms=None, allow_nan_equality=False, - ignore_column_order=False, ignore_row_order=False, formats=DefaultFormats()): +def assert_approx_df_equality( + df1, + df2, + precision, + ignore_nullable=False, + transforms=None, + allow_nan_equality=False, + ignore_column_order=False, + ignore_row_order=False, + formats: FormattingConfig | None = None, +): + if not formats: + formats = FormattingConfig() + elif not isinstance(formats, FormattingConfig): + formats = FormattingConfig._from_arbitrary_dataclass(formats) + if transforms is None: transforms = [] if ignore_column_order: @@ -49,7 +95,13 @@ def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transf df2 = reduce(lambda acc, fn: fn(acc), transforms, df2) assert_schema_equality(df1.schema, df2.schema, ignore_nullable) if precision != 0: - assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality], formats) + assert_generic_rows_equality( + df1.collect(), + df2.collect(), + are_rows_approx_equal, + [precision, allow_nan_equality], + formats, + ) elif allow_nan_equality: assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], formats) else: diff --git a/chispa/default_formats.py b/chispa/default_formats.py index 60654d7..ab1f292 100644 --- a/chispa/default_formats.py +++ b/chispa/default_formats.py @@ -1,8 +1,21 @@ -from dataclasses import dataclass +from __future__ 
import annotations + +import warnings +from dataclasses import dataclass, field + @dataclass class DefaultFormats: - mismatched_rows = ["red"] - matched_rows = ["blue"] - mismatched_cells = ["red", "underline"] - matched_cells = ["blue"] + """ + This class is now deprecated and should be removed in a future release. + """ + + mismatched_rows: list[str] = field(default_factory=lambda: ["red"]) + matched_rows: list[str] = field(default_factory=lambda: ["blue"]) + mismatched_cells: list[str] = field(default_factory=lambda: ["red", "underline"]) + matched_cells: list[str] = field(default_factory=lambda: ["blue"]) + + def __post_init__(self): + warnings.warn( + "DefaultFormats is deprecated. Use `chispa.formatting.FormattingConfig` instead.", DeprecationWarning + ) diff --git a/chispa/formatting/__init__.py b/chispa/formatting/__init__.py new file mode 100644 index 0000000..6d107b0 --- /dev/null +++ b/chispa/formatting/__init__.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from chispa.formatting.format_string import format_string +from chispa.formatting.formats import RESET, Color, Format, Style +from chispa.formatting.formatting_config import FormattingConfig + +__all__ = ("Style", "Color", "FormattingConfig", "Format", "format_string", "RESET") diff --git a/chispa/formatting/format_string.py b/chispa/formatting/format_string.py new file mode 100644 index 0000000..0725744 --- /dev/null +++ b/chispa/formatting/format_string.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from chispa.formatting.formats import RESET, Format + + +def format_string(input_string: str, format: Format) -> str: + if not format.color and not format.style: + return input_string + + formatted_string = input_string + codes = [] + + if format.style: + for style in format.style: + codes.append(style.value) + + if format.color: + codes.append(format.color.value) + + formatted_string = "".join(codes) + formatted_string + RESET + return formatted_string diff --git a/chispa/formatting/formats.py b/chispa/formatting/formats.py new file mode 100644 index 0000000..76064d8 --- /dev/null +++ b/chispa/formatting/formats.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +RESET = "\033[0m" + + +class Color(str, Enum): + """ + Enum for terminal colors. + Each color is represented by its corresponding ANSI escape code. + """ + + BLACK = "\033[30m" + RED = "\033[31m" + GREEN = "\033[32m" + YELLOW = "\033[33m" + BLUE = "\033[34m" + PURPLE = "\033[35m" + CYAN = "\033[36m" + LIGHT_GRAY = "\033[37m" + DARK_GRAY = "\033[90m" + LIGHT_RED = "\033[91m" + LIGHT_GREEN = "\033[92m" + LIGHT_YELLOW = "\033[93m" + LIGHT_BLUE = "\033[94m" + LIGHT_PURPLE = "\033[95m" + LIGHT_CYAN = "\033[96m" + WHITE = "\033[97m" + + +class Style(str, Enum): + """ + Enum for text styles. + Each style is represented by its corresponding ANSI escape code. + """ + + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + BLINK = "\033[5m" + INVERT = "\033[7m" + HIDE = "\033[8m" + + +@dataclass +class Format: + """ + Data class to represent text formatting with color and style. + + Attributes: + color (Color | None): The color for the text. + style (list[Style] | None): A list of styles for the text. + """ + + color: Color | None = None + style: list[Style] | None = None + + @classmethod + def from_dict(cls, format_dict: dict) -> Format: + """ + Create a Format instance from a dictionary. + + Args: + format_dict (dict): A dictionary with keys 'color' and/or 'style'. 
+ """ + if not isinstance(format_dict, dict): + raise ValueError("Input must be a dictionary") + + valid_keys = {"color", "style"} + invalid_keys = set(format_dict) - valid_keys + if invalid_keys: + raise ValueError(f"Invalid keys in format dictionary: {invalid_keys}. Valid keys are {valid_keys}") + + color = cls._get_color_enum(format_dict.get("color")) + style = format_dict.get("style") + if isinstance(style, str): + styles = [cls._get_style_enum(style)] + elif isinstance(style, list): + styles = [cls._get_style_enum(s) for s in style] + else: + styles = None + + return cls(color=color, style=styles) + + @classmethod + def from_list(cls, values: list[str]) -> Format: + """ + Create a Format instance from a list of strings. + + Args: + values (list[str]): A list of strings representing colors and styles. + """ + if not all(isinstance(value, str) for value in values): + raise ValueError("All elements in the list must be strings") + + color = None + styles = [] + valid_colors = [c.name.lower() for c in Color] + valid_styles = [s.name.lower() for s in Style] + + for value in values: + if value in valid_colors: + color = Color[value.upper()] + elif value in valid_styles: + styles.append(Style[value.upper()]) + else: + raise ValueError( + f"Invalid value: {value}. Valid values are colors: {valid_colors} and styles: {valid_styles}" + ) + + return cls(color=color, style=styles if styles else None) + + @staticmethod + def _get_color_enum(color: Color | str | None) -> Color | None: + if isinstance(color, Color): + return color + elif isinstance(color, str): + try: + return Color[color.upper()] + except KeyError: + valid_colors = [c.name.lower() for c in Color] + raise ValueError(f"Invalid color name: {color}. Valid color names are {valid_colors}") + return None + + @staticmethod + def _get_style_enum(style: Style | str | None) -> Style | None: + if isinstance(style, Style): + return style + elif isinstance(style, str): + try: + return Style[style.upper()] + except KeyError: + valid_styles = [f.name.lower() for f in Style] + raise ValueError(f"Invalid style name: {style}. Valid style names are {valid_styles}") + return None diff --git a/chispa/formatting/formatting_config.py b/chispa/formatting/formatting_config.py new file mode 100644 index 0000000..055428f --- /dev/null +++ b/chispa/formatting/formatting_config.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import warnings +from typing import Any, ClassVar + +from chispa.default_formats import DefaultFormats +from chispa.formatting.formats import Color, Format, Style + + +class FormattingConfig: + """ + Class to manage and parse formatting configurations. + """ + + VALID_KEYS: ClassVar = {"color", "style"} + + def __init__( + self, + mismatched_rows: Format | dict = Format(Color.RED), + matched_rows: Format | dict = Format(Color.BLUE), + mismatched_cells: Format | dict = Format(Color.RED, [Style.UNDERLINE]), + matched_cells: Format | dict = Format(Color.BLUE), + ): + """ + Initializes the FormattingConfig with given or default formatting. + + Each of the arguments can be provided as a `Format` object or a dictionary with the following keys: + - 'color': A string representing a color name, which should be one of the valid colors: + ['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', + 'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', + 'light_purple', 'light_cyan', 'white']. 
+ - 'style': A string or list of strings representing styles, which should be one of the valid styles: + ['bold', 'underline', 'blink', 'invert', 'hide']. + + Args: + mismatched_rows (Format | dict): Format or dictionary for mismatched rows. + matched_rows (Format | dict): Format or dictionary for matched rows. + mismatched_cells (Format | dict): Format or dictionary for mismatched cells. + matched_cells (Format | dict): Format or dictionary for matched cells. + + Raises: + ValueError: If the dictionary contains invalid keys or values. + """ + self.mismatched_rows: Format = self._parse_format(mismatched_rows) + self.matched_rows: Format = self._parse_format(matched_rows) + self.mismatched_cells: Format = self._parse_format(mismatched_cells) + self.matched_cells: Format = self._parse_format(matched_cells) + + def _parse_format(self, format: Format | dict) -> Format: + if isinstance(format, Format): + return format + elif isinstance(format, dict): + return Format.from_dict(format) + raise ValueError("Invalid format type. Must be Format or dict.") + + @classmethod + def _from_arbitrary_dataclass(cls, instance: Any) -> FormattingConfig: + """ + Converts an instance of an arbitrary class with specified fields to a FormattingConfig instance. + This method is purely for backwards compatibility and should be removed in a future release, + together with the `DefaultFormats` class. + """ + + if not isinstance(instance, DefaultFormats): + warnings.warn( + "Using an arbitrary dataclass is deprecated. Use `chispa.formatting.FormattingConfig` instead.", + DeprecationWarning, + ) + + mismatched_rows = Format.from_list(getattr(instance, "mismatched_rows")) + matched_rows = Format.from_list(getattr(instance, "matched_rows")) + mismatched_cells = Format.from_list(getattr(instance, "mismatched_cells")) + matched_cells = Format.from_list(getattr(instance, "matched_cells")) + + return cls( + mismatched_rows=mismatched_rows, + matched_rows=matched_rows, + mismatched_cells=mismatched_cells, + matched_cells=matched_cells, + ) diff --git a/chispa/number_helpers.py b/chispa/number_helpers.py index 9f8713c..a49c84a 100644 --- a/chispa/number_helpers.py +++ b/chispa/number_helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math @@ -13,4 +15,4 @@ def nan_safe_equality(x, y) -> bool: def nan_safe_approx_equality(x, y, precision) -> bool: - return (abs(x-y)<=precision) or (isnan(x) and isnan(y)) \ No newline at end of file + return (abs(x - y) <= precision) or (isnan(x) and isnan(y)) diff --git a/chispa/prettytable.py b/chispa/prettytable.py deleted file mode 100644 index b62358a..0000000 --- a/chispa/prettytable.py +++ /dev/null @@ -1,1972 +0,0 @@ -#!/usr/bin/env python -# -# Copyright (c) 2009-2014, Luke Maurits -# All rights reserved. -# With contributions from: -# * Chris Clark -# * Klein Stephane -# * John Filleau -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * The name of the author may not be used to endorse or promote products -# derived from this software without specific prior written permission. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. - -import copy -import csv -import itertools -import json -import math -import random -import re -import sys -import textwrap -import unicodedata - -__version__ = "0.7.2" -py3k = sys.version_info[0] >= 3 -if py3k: - unicode = str - basestring = str - itermap = map - iterzip = zip - uni_chr = chr - from html.parser import HTMLParser - from html import escape -else: - itermap = itertools.imap - iterzip = itertools.izip - uni_chr = unichr # noqa: F821 - from HTMLParser import HTMLParser - from cgi import escape - -# hrule styles -FRAME = 0 -ALL = 1 -NONE = 2 -HEADER = 3 - -# Table styles -DEFAULT = 10 -MSWORD_FRIENDLY = 11 -PLAIN_COLUMNS = 12 -MARKDOWN = 13 -RANDOM = 20 - -_re = re.compile(r"\033\[[0-9;]*m") - - -def _get_size(text): - lines = text.split("\n") - height = len(lines) - width = max(_str_block_width(line) for line in lines) - return (width, height) - - -class PrettyTable(object): - def __init__(self, field_names=None, **kwargs): - - """Return a new PrettyTable instance - - Arguments: - - encoding - Unicode encoding scheme used to decode any encoded input - title - optional table title - field_names - list or tuple of field names - fields - list or tuple of field names to include in displays - start - index of first data row to include in output - end - index of last data row to include in output PLUS ONE (list slice style) - header - print a header showing field names (True or False) - header_style - stylisation to apply to field names in header - ("cap", "title", "upper", "lower" or None) - border - print a border around the table (True or False) - hrules - controls printing of horizontal rules after rows. - Allowed values: FRAME, HEADER, ALL, NONE - vrules - controls printing of vertical rules between columns. 
- Allowed values: FRAME, ALL, NONE - int_format - controls formatting of integer data - float_format - controls formatting of floating point data - min_table_width - minimum desired table width, in characters - max_table_width - maximum desired table width, in characters - padding_width - number of spaces on either side of column data - (only used if left and right paddings are None) - left_padding_width - number of spaces on left hand side of column data - right_padding_width - number of spaces on right hand side of column data - vertical_char - single character string used to draw vertical lines - horizontal_char - single character string used to draw horizontal lines - junction_char - single character string used to draw line junctions - sortby - name of field to sort rows by - sort_key - sorting key function, applied to data points before sorting - valign - default valign for each row (None, "t", "m" or "b") - reversesort - True or False to sort in descending or ascending order - oldsortslice - Slice rows before sorting in the "old style" """ - - self.encoding = kwargs.get("encoding", "UTF-8") - - # Data - self._field_names = [] - self._rows = [] - self.align = {} - self.valign = {} - self.max_width = {} - self.min_width = {} - self.int_format = {} - self.float_format = {} - if field_names: - self.field_names = field_names - else: - self._widths = [] - - # Options - self._options = ( - "title start end fields header border sortby reversesort " - "sort_key attributes format hrules vrules".split() - ) - self._options.extend( - "int_format float_format min_table_width max_table_width padding_width " - "left_padding_width right_padding_width".split() - ) - self._options.extend( - "vertical_char horizontal_char junction_char header_style valign xhtml " - "print_empty oldsortslice".split() - ) - self._options.extend("align valign max_width min_width".split()) - for option in self._options: - if option in kwargs: - self._validate_option(option, kwargs[option]) - else: - kwargs[option] = None - - self._title = kwargs["title"] or None - self._start = kwargs["start"] or 0 - self._end = kwargs["end"] or None - self._fields = kwargs["fields"] or None - - if kwargs["header"] in (True, False): - self._header = kwargs["header"] - else: - self._header = True - self._header_style = kwargs["header_style"] or None - if kwargs["border"] in (True, False): - self._border = kwargs["border"] - else: - self._border = True - self._hrules = kwargs["hrules"] or FRAME - self._vrules = kwargs["vrules"] or ALL - - self._sortby = kwargs["sortby"] or None - if kwargs["reversesort"] in (True, False): - self._reversesort = kwargs["reversesort"] - else: - self._reversesort = False - self._sort_key = kwargs["sort_key"] or (lambda x: x) - - # Column specific arguments, use property.setters - self.align = kwargs["align"] or {} - self.valign = kwargs["valign"] or {} - self.max_width = kwargs["max_width"] or {} - self.min_width = kwargs["min_width"] or {} - self.int_format = kwargs["int_format"] or {} - self.float_format = kwargs["float_format"] or {} - - self._min_table_width = kwargs["min_table_width"] or None - self._max_table_width = kwargs["max_table_width"] or None - self._padding_width = kwargs["padding_width"] or 1 - self._left_padding_width = kwargs["left_padding_width"] or None - self._right_padding_width = kwargs["right_padding_width"] or None - - self._vertical_char = kwargs["vertical_char"] or self._unicode("|") - self._horizontal_char = kwargs["horizontal_char"] or self._unicode("-") - self._junction_char = 
kwargs["junction_char"] or self._unicode("+") - - if kwargs["print_empty"] in (True, False): - self._print_empty = kwargs["print_empty"] - else: - self._print_empty = True - if kwargs["oldsortslice"] in (True, False): - self._oldsortslice = kwargs["oldsortslice"] - else: - self._oldsortslice = False - self._format = kwargs["format"] or False - self._xhtml = kwargs["xhtml"] or False - self._attributes = kwargs["attributes"] or {} - - def _unicode(self, value): - if not isinstance(value, basestring): - value = str(value) - if not isinstance(value, unicode): - value = unicode(value, self.encoding, "strict") - return value - - def _justify(self, text, width, align): - excess = width - _str_block_width(text) - if align == "l": - return text + excess * " " - elif align == "r": - return excess * " " + text - else: - if excess % 2: - # Uneven padding - # Put more space on right if text is of odd length... - if _str_block_width(text) % 2: - return (excess // 2) * " " + text + (excess // 2 + 1) * " " - # and more space on left if text is of even length - else: - return (excess // 2 + 1) * " " + text + (excess // 2) * " " - # Why distribute extra space this way? To match the behaviour of - # the inbuilt str.center() method. - else: - # Equal padding on either side - return (excess // 2) * " " + text + (excess // 2) * " " - - def __getattr__(self, name): - - if name == "rowcount": - return len(self._rows) - elif name == "colcount": - if self._field_names: - return len(self._field_names) - elif self._rows: - return len(self._rows[0]) - else: - return 0 - else: - raise AttributeError(name) - - def __getitem__(self, index): - - new = PrettyTable() - new.field_names = self.field_names - for attr in self._options: - setattr(new, "_" + attr, getattr(self, "_" + attr)) - setattr(new, "_align", getattr(self, "_align")) - if isinstance(index, slice): - for row in self._rows[index]: - new.add_row(row) - elif isinstance(index, int): - new.add_row(self._rows[index]) - else: - raise Exception( - "Index %s is invalid, must be an integer or slice" % str(index) - ) - return new - - if py3k: - - def __str__(self): - return self.__unicode__() - - else: - - def __str__(self): - return self.__unicode__().encode(self.encoding) - - def __unicode__(self): - return self.get_string() - - ############################## - # ATTRIBUTE VALIDATORS # - ############################## - - # The method _validate_option is all that should be used elsewhere in the code base - # to validate options. It will call the appropriate validation method for that - # option. The individual validation methods should never need to be called directly - # (although nothing bad will happen if they *are*). - # Validation happens in TWO places. - # Firstly, in the property setters defined in the ATTRIBUTE MANAGEMENT section. 
- # Secondly, in the _get_options method, where keyword arguments are mixed with - # persistent settings - - def _validate_option(self, option, val): - if option in ("field_names"): - self._validate_field_names(val) - elif option in ( - "start", - "end", - "max_width", - "min_width", - "min_table_width", - "max_table_width", - "padding_width", - "left_padding_width", - "right_padding_width", - "format", - ): - self._validate_nonnegative_int(option, val) - elif option in ("sortby"): - self._validate_field_name(option, val) - elif option in ("sort_key"): - self._validate_function(option, val) - elif option in ("hrules"): - self._validate_hrules(option, val) - elif option in ("vrules"): - self._validate_vrules(option, val) - elif option in ("fields"): - self._validate_all_field_names(option, val) - elif option in ( - "header", - "border", - "reversesort", - "xhtml", - "print_empty", - "oldsortslice", - ): - self._validate_true_or_false(option, val) - elif option in ("header_style"): - self._validate_header_style(val) - elif option in ("int_format"): - self._validate_int_format(option, val) - elif option in ("float_format"): - self._validate_float_format(option, val) - elif option in ("vertical_char", "horizontal_char", "junction_char"): - self._validate_single_char(option, val) - elif option in ("attributes"): - self._validate_attributes(option, val) - - def _validate_field_names(self, val): - # Check for appropriate length - if self._field_names: - try: - assert len(val) == len(self._field_names) - except AssertionError: - raise Exception( - "Field name list has incorrect number of values, " - "(actual) %d!=%d (expected)" % (len(val), len(self._field_names)) - ) - if self._rows: - try: - assert len(val) == len(self._rows[0]) - except AssertionError: - raise Exception( - "Field name list has incorrect number of values, " - "(actual) %d!=%d (expected)" % (len(val), len(self._rows[0])) - ) - # Check for uniqueness - try: - assert len(val) == len(set(val)) - except AssertionError: - raise Exception("Field names must be unique!") - - def _validate_header_style(self, val): - try: - assert val in ("cap", "title", "upper", "lower", None) - except AssertionError: - raise Exception( - "Invalid header style, use cap, title, upper, lower or None!" - ) - - def _validate_align(self, val): - try: - assert val in ["l", "c", "r"] - except AssertionError: - raise Exception("Alignment %s is invalid, use l, c or r!" % val) - - def _validate_valign(self, val): - try: - assert val in ["t", "m", "b", None] - except AssertionError: - raise Exception("Alignment %s is invalid, use t, m, b or None!" % val) - - def _validate_nonnegative_int(self, name, val): - try: - assert int(val) >= 0 - except AssertionError: - raise Exception( - "Invalid value for {}: {}!".format(name, self._unicode(val)) - ) - - def _validate_true_or_false(self, name, val): - try: - assert val in (True, False) - except AssertionError: - raise Exception("Invalid value for %s! Must be True or False." % name) - - def _validate_int_format(self, name, val): - if val == "": - return - try: - assert type(val) in (str, unicode) - assert val.isdigit() - except AssertionError: - raise Exception( - "Invalid value for %s! Must be an integer format string." % name - ) - - def _validate_float_format(self, name, val): - if val == "": - return - try: - assert type(val) in (str, unicode) - assert "." 
in val - bits = val.split(".") - assert len(bits) <= 2 - assert bits[0] == "" or bits[0].isdigit() - assert ( - bits[1] == "" - or bits[1].isdigit() - or (bits[1][-1] == "f" and bits[1].rstrip("f").isdigit()) - ) - except AssertionError: - raise Exception( - "Invalid value for %s! Must be a float format string." % name - ) - - def _validate_function(self, name, val): - try: - assert hasattr(val, "__call__") - except AssertionError: - raise Exception("Invalid value for %s! Must be a function." % name) - - def _validate_hrules(self, name, val): - try: - assert val in (ALL, FRAME, HEADER, NONE) - except AssertionError: - raise Exception( - "Invalid value for %s! Must be ALL, FRAME, HEADER or NONE." % name - ) - - def _validate_vrules(self, name, val): - try: - assert val in (ALL, FRAME, NONE) - except AssertionError: - raise Exception( - "Invalid value for %s! Must be ALL, FRAME, or NONE." % name - ) - - def _validate_field_name(self, name, val): - try: - assert (val in self._field_names) or (val is None) - except AssertionError: - raise Exception("Invalid field name: %s!" % val) - - def _validate_all_field_names(self, name, val): - try: - for x in val: - self._validate_field_name(name, x) - except AssertionError: - raise Exception("fields must be a sequence of field names!") - - def _validate_single_char(self, name, val): - try: - assert _str_block_width(val) == 1 - except AssertionError: - raise Exception( - "Invalid value for %s! Must be a string of length 1." % name - ) - - def _validate_attributes(self, name, val): - try: - assert isinstance(val, dict) - except AssertionError: - raise Exception("attributes must be a dictionary of name/value pairs!") - - ############################## - # ATTRIBUTE MANAGEMENT # - ############################## - - @property - def field_names(self): - """List or tuple of field names""" - return self._field_names - - @field_names.setter - def field_names(self, val): - val = [self._unicode(x) for x in val] - self._validate_option("field_names", val) - old_names = None - if self._field_names: - old_names = self._field_names[:] - self._field_names = val - if self._align and old_names: - for old_name, new_name in zip(old_names, val): - self._align[new_name] = self._align[old_name] - for old_name in old_names: - if old_name not in self._align: - self._align.pop(old_name) - else: - self.align = "c" - if self._valign and old_names: - for old_name, new_name in zip(old_names, val): - self._valign[new_name] = self._valign[old_name] - for old_name in old_names: - if old_name not in self._valign: - self._valign.pop(old_name) - else: - self.valign = "t" - - @property - def align(self): - """Controls alignment of fields - Arguments: - - align - alignment, one of "l", "c", or "r" """ - return self._align - - @align.setter - def align(self, val): - if not self._field_names: - self._align = {} - elif val is None or (isinstance(val, dict) and len(val) == 0): - for field in self._field_names: - self._align[field] = "c" - else: - self._validate_align(val) - for field in self._field_names: - self._align[field] = val - - @property - def valign(self): - """Controls vertical alignment of fields - Arguments: - - valign - vertical alignment, one of "t", "m", or "b" """ - return self._valign - - @valign.setter - def valign(self, val): - if not self._field_names: - self._valign = {} - elif val is None or (isinstance(val, dict) and len(val) == 0): - for field in self._field_names: - self._valign[field] = "t" - else: - self._validate_valign(val) - for field in self._field_names: - 
self._valign[field] = val - - @property - def max_width(self): - """Controls maximum width of fields - Arguments: - - max_width - maximum width integer""" - return self._max_width - - @max_width.setter - def max_width(self, val): - if val is None or (isinstance(val, dict) and len(val) == 0): - self._max_width = {} - else: - self._validate_option("max_width", val) - for field in self._field_names: - self._max_width[field] = val - - @property - def min_width(self): - """Controls minimum width of fields - Arguments: - - min_width - minimum width integer""" - return self._min_width - - @min_width.setter - def min_width(self, val): - if val is None or (isinstance(val, dict) and len(val) == 0): - self._min_width = {} - else: - self._validate_option("min_width", val) - for field in self._field_names: - self._min_width[field] = val - - @property - def min_table_width(self): - return self._min_table_width - - @min_table_width.setter - def min_table_width(self, val): - self._validate_option("min_table_width", val) - self._min_table_width = val - - @property - def max_table_width(self): - return self._max_table_width - - @max_table_width.setter - def max_table_width(self, val): - self._validate_option("max_table_width", val) - self._max_table_width = val - - @property - def fields(self): - """List or tuple of field names to include in displays""" - return self._fields - - @fields.setter - def fields(self, val): - self._validate_option("fields", val) - self._fields = val - - @property - def title(self): - """Optional table title - - Arguments: - - title - table title""" - return self._title - - @title.setter - def title(self, val): - self._title = self._unicode(val) - - @property - def start(self): - """Start index of the range of rows to print - - Arguments: - - start - index of first data row to include in output""" - return self._start - - @start.setter - def start(self, val): - self._validate_option("start", val) - self._start = val - - @property - def end(self): - """End index of the range of rows to print - - Arguments: - - end - index of last data row to include in output PLUS ONE (list slice style)""" - return self._end - - @end.setter - def end(self, val): - self._validate_option("end", val) - self._end = val - - @property - def sortby(self): - """Name of field by which to sort rows - - Arguments: - - sortby - field name to sort by""" - return self._sortby - - @sortby.setter - def sortby(self, val): - self._validate_option("sortby", val) - self._sortby = val - - @property - def reversesort(self): - """Controls direction of sorting (ascending vs descending) - - Arguments: - - reveresort - set to True to sort by descending order, or False to sort by - ascending order""" - return self._reversesort - - @reversesort.setter - def reversesort(self, val): - self._validate_option("reversesort", val) - self._reversesort = val - - @property - def sort_key(self): - """Sorting key function, applied to data points before sorting - - Arguments: - - sort_key - a function which takes one argument and returns something to be - sorted""" - return self._sort_key - - @sort_key.setter - def sort_key(self, val): - self._validate_option("sort_key", val) - self._sort_key = val - - @property - def header(self): - """Controls printing of table header with field names - - Arguments: - - header - print a header showing field names (True or False)""" - return self._header - - @header.setter - def header(self, val): - self._validate_option("header", val) - self._header = val - - @property - def header_style(self): - 
"""Controls stylisation applied to field names in header - - Arguments: - - header_style - stylisation to apply to field names in header - ("cap", "title", "upper", "lower" or None)""" - return self._header_style - - @header_style.setter - def header_style(self, val): - self._validate_header_style(val) - self._header_style = val - - @property - def border(self): - """Controls printing of border around table - - Arguments: - - border - print a border around the table (True or False)""" - return self._border - - @border.setter - def border(self, val): - self._validate_option("border", val) - self._border = val - - @property - def hrules(self): - """Controls printing of horizontal rules after rows - - Arguments: - - hrules - horizontal rules style. Allowed values: FRAME, ALL, HEADER, NONE""" - return self._hrules - - @hrules.setter - def hrules(self, val): - self._validate_option("hrules", val) - self._hrules = val - - @property - def vrules(self): - """Controls printing of vertical rules between columns - - Arguments: - - vrules - vertical rules style. Allowed values: FRAME, ALL, NONE""" - return self._vrules - - @vrules.setter - def vrules(self, val): - self._validate_option("vrules", val) - self._vrules = val - - @property - def int_format(self): - """Controls formatting of integer data - Arguments: - - int_format - integer format string""" - return self._int_format - - @int_format.setter - def int_format(self, val): - if val is None or (isinstance(val, dict) and len(val) == 0): - self._int_format = {} - else: - self._validate_option("int_format", val) - for field in self._field_names: - self._int_format[field] = val - - @property - def float_format(self): - """Controls formatting of floating point data - Arguments: - - float_format - floating point format string""" - return self._float_format - - @float_format.setter - def float_format(self, val): - if val is None or (isinstance(val, dict) and len(val) == 0): - self._float_format = {} - else: - self._validate_option("float_format", val) - for field in self._field_names: - self._float_format[field] = val - - @property - def padding_width(self): - """The number of empty spaces between a column's edge and its content - - Arguments: - - padding_width - number of spaces, must be a positive integer""" - return self._padding_width - - @padding_width.setter - def padding_width(self, val): - self._validate_option("padding_width", val) - self._padding_width = val - - @property - def left_padding_width(self): - """The number of empty spaces between a column's left edge and its content - - Arguments: - - left_padding - number of spaces, must be a positive integer""" - return self._left_padding_width - - @left_padding_width.setter - def left_padding_width(self, val): - self._validate_option("left_padding_width", val) - self._left_padding_width = val - - @property - def right_padding_width(self): - """The number of empty spaces between a column's right edge and its content - - Arguments: - - right_padding - number of spaces, must be a positive integer""" - return self._right_padding_width - - @right_padding_width.setter - def right_padding_width(self, val): - self._validate_option("right_padding_width", val) - self._right_padding_width = val - - @property - def vertical_char(self): - """The charcter used when printing table borders to draw vertical lines - - Arguments: - - vertical_char - single character string used to draw vertical lines""" - return self._vertical_char - - @vertical_char.setter - def vertical_char(self, val): - val = 
self._unicode(val) - self._validate_option("vertical_char", val) - self._vertical_char = val - - @property - def horizontal_char(self): - """The charcter used when printing table borders to draw horizontal lines - - Arguments: - - horizontal_char - single character string used to draw horizontal lines""" - return self._horizontal_char - - @horizontal_char.setter - def horizontal_char(self, val): - val = self._unicode(val) - self._validate_option("horizontal_char", val) - self._horizontal_char = val - - @property - def junction_char(self): - """The charcter used when printing table borders to draw line junctions - - Arguments: - - junction_char - single character string used to draw line junctions""" - return self._junction_char - - @junction_char.setter - def junction_char(self, val): - val = self._unicode(val) - self._validate_option("vertical_char", val) - self._junction_char = val - - @property - def format(self): - """Controls whether or not HTML tables are formatted to match styling options - - Arguments: - - format - True or False""" - return self._format - - @format.setter - def format(self, val): - self._validate_option("format", val) - self._format = val - - @property - def print_empty(self): - """Controls whether or not empty tables produce a header and frame or just an - empty string - - Arguments: - - print_empty - True or False""" - return self._print_empty - - @print_empty.setter - def print_empty(self, val): - self._validate_option("print_empty", val) - self._print_empty = val - - @property - def attributes(self): - """A dictionary of HTML attribute name/value pairs to be included in the - tag when printing HTML - - Arguments: - - attributes - dictionary of attributes""" - return self._attributes - - @attributes.setter - def attributes(self, val): - self._validate_option("attributes", val) - self._attributes = val - - @property - def oldsortslice(self): - """ oldsortslice - Slice rows before sorting in the "old style" """ - return self._oldsortslice - - @oldsortslice.setter - def oldsortslice(self, val): - self._validate_option("oldsortslice", val) - self._oldsortslice = val - - ############################## - # OPTION MIXER # - ############################## - - def _get_options(self, kwargs): - - options = {} - for option in self._options: - if option in kwargs: - self._validate_option(option, kwargs[option]) - options[option] = kwargs[option] - else: - options[option] = getattr(self, "_" + option) - return options - - ############################## - # PRESET STYLE LOGIC # - ############################## - - def set_style(self, style): - - if style == DEFAULT: - self._set_default_style() - elif style == MSWORD_FRIENDLY: - self._set_msword_style() - elif style == PLAIN_COLUMNS: - self._set_columns_style() - elif style == MARKDOWN: - self._set_markdown_style() - elif style == RANDOM: - self._set_random_style() - else: - raise Exception("Invalid pre-set style!") - - def _set_markdown_style(self): - self.header = True - self.border = True - self._hrules = None - self.padding_width = 1 - self.left_padding_width = 1 - self.right_padding_width = 1 - self.vertical_char = "|" - self.junction_char = "|" - - def _set_default_style(self): - - self.header = True - self.border = True - self._hrules = FRAME - self._vrules = ALL - self.padding_width = 1 - self.left_padding_width = 1 - self.right_padding_width = 1 - self.vertical_char = "|" - self.horizontal_char = "-" - self.junction_char = "+" - - def _set_msword_style(self): - - self.header = True - self.border = True - self._hrules 
= NONE - self.padding_width = 1 - self.left_padding_width = 1 - self.right_padding_width = 1 - self.vertical_char = "|" - - def _set_columns_style(self): - - self.header = True - self.border = False - self.padding_width = 1 - self.left_padding_width = 0 - self.right_padding_width = 8 - - def _set_random_style(self): - - # Just for fun! - self.header = random.choice((True, False)) - self.border = random.choice((True, False)) - self._hrules = random.choice((ALL, FRAME, HEADER, NONE)) - self._vrules = random.choice((ALL, FRAME, NONE)) - self.left_padding_width = random.randint(0, 5) - self.right_padding_width = random.randint(0, 5) - self.vertical_char = random.choice(r"~!@#$%^&*()_+|-=\{}[];':\",./;<>?") - self.horizontal_char = random.choice(r"~!@#$%^&*()_+|-=\{}[];':\",./;<>?") - self.junction_char = random.choice(r"~!@#$%^&*()_+|-=\{}[];':\",./;<>?") - - ############################## - # DATA INPUT METHODS # - ############################## - - def add_row(self, row): - - """Add a row to the table - - Arguments: - - row - row of data, should be a list with as many elements as the table - has fields""" - - if self._field_names and len(row) != len(self._field_names): - raise Exception( - "Row has incorrect number of values, (actual) %d!=%d (expected)" - % (len(row), len(self._field_names)) - ) - if not self._field_names: - self.field_names = [("Field %d" % (n + 1)) for n in range(0, len(row))] - self._rows.append(list(row)) - - def del_row(self, row_index): - - """Delete a row to the table - - Arguments: - - row_index - The index of the row you want to delete. Indexing starts at 0.""" - - if row_index > len(self._rows) - 1: - raise Exception( - "Cant delete row at index %d, table only has %d rows!" - % (row_index, len(self._rows)) - ) - del self._rows[row_index] - - def add_column(self, fieldname, column, align="c", valign="t"): - - """Add a column to the table. - - Arguments: - - fieldname - name of the field to contain the new column of data - column - column of data, should be a list with as many elements as the - table has rows - align - desired alignment for this column - "l" for left, "c" for centre and - "r" for right - valign - desired vertical alignment for new columns - "t" for top, - "m" for middle and "b" for bottom""" - - if len(self._rows) in (0, len(column)): - self._validate_align(align) - self._validate_valign(valign) - self._field_names.append(fieldname) - self._align[fieldname] = align - self._valign[fieldname] = valign - for i in range(0, len(column)): - if len(self._rows) < i + 1: - self._rows.append([]) - self._rows[i].append(column[i]) - else: - raise Exception( - "Column length %d does not match number of rows %d!" 
- % (len(column), len(self._rows)) - ) - - def clear_rows(self): - - """Delete all rows from the table but keep the current field names""" - - self._rows = [] - - def clear(self): - - """Delete all rows and field names from the table, maintaining nothing but - styling options""" - - self._rows = [] - self._field_names = [] - self._widths = [] - - ############################## - # MISC PUBLIC METHODS # - ############################## - - def copy(self): - return copy.deepcopy(self) - - ############################## - # MISC PRIVATE METHODS # - ############################## - - def _format_value(self, field, value): - if isinstance(value, int) and field in self._int_format: - value = self._unicode(("%%%sd" % self._int_format[field]) % value) - elif isinstance(value, float) and field in self._float_format: - value = self._unicode(("%%%sf" % self._float_format[field]) % value) - return self._unicode(value) - - def _compute_table_width(self, options): - table_width = 2 if options["vrules"] in (FRAME, ALL) else 0 - per_col_padding = sum(self._get_padding_widths(options)) - for index, fieldname in enumerate(self.field_names): - if not options["fields"] or ( - options["fields"] and fieldname in options["fields"] - ): - table_width += self._widths[index] + per_col_padding - return table_width - - def _compute_widths(self, rows, options): - if options["header"]: - widths = [_get_size(field)[0] for field in self._field_names] - else: - widths = len(self.field_names) * [0] - - for row in rows: - for index, value in enumerate(row): - fieldname = self.field_names[index] - if fieldname in self.max_width: - widths[index] = max( - widths[index], - min(_get_size(value)[0], self.max_width[fieldname]), - ) - else: - widths[index] = max(widths[index], _get_size(value)[0]) - if fieldname in self.min_width: - widths[index] = max(widths[index], self.min_width[fieldname]) - self._widths = widths - - # Are we exceeding max_table_width? - if self._max_table_width: - table_width = self._compute_table_width(options) - if table_width > self._max_table_width: - # Shrink widths in proportion - scale = 1.0 * self._max_table_width / table_width - widths = [int(math.floor(w * scale)) for w in widths] - self._widths = widths - - # Are we under min_table_width or title width? - if self._min_table_width or options["title"]: - if options["title"]: - title_width = len(options["title"]) + sum( - self._get_padding_widths(options) - ) - if options["vrules"] in (FRAME, ALL): - title_width += 2 - else: - title_width = 0 - min_table_width = self.min_table_width or 0 - min_width = max(title_width, min_table_width) - table_width = self._compute_table_width(options) - if table_width < min_width: - # Grow widths in proportion - scale = 1.0 * min_width / table_width - widths = [int(math.ceil(w * scale)) for w in widths] - self._widths = widths - - def _get_padding_widths(self, options): - - if options["left_padding_width"] is not None: - lpad = options["left_padding_width"] - else: - lpad = options["padding_width"] - if options["right_padding_width"] is not None: - rpad = options["right_padding_width"] - else: - rpad = options["padding_width"] - return lpad, rpad - - def _get_rows(self, options): - """Return only those data rows that should be printed, based on slicing and - sorting. 
- - Arguments: - - options - dictionary of option settings.""" - - if options["oldsortslice"]: - rows = copy.deepcopy(self._rows[options["start"] : options["end"]]) - else: - rows = copy.deepcopy(self._rows) - - # Sort - if options["sortby"]: - sortindex = self._field_names.index(options["sortby"]) - # Decorate - rows = [[row[sortindex]] + row for row in rows] - # Sort - rows.sort(reverse=options["reversesort"], key=options["sort_key"]) - # Undecorate - rows = [row[1:] for row in rows] - - # Slice if necessary - if not options["oldsortslice"]: - rows = rows[options["start"] : options["end"]] - - return rows - - def _format_row(self, row, options): - return [ - self._format_value(field, value) - for (field, value) in zip(self._field_names, row) - ] - - def _format_rows(self, rows, options): - return [self._format_row(row, options) for row in rows] - - ############################## - # PLAIN TEXT STRING METHODS # - ############################## - - def get_string(self, **kwargs): - - """Return string representation of table in current state. - - Arguments: - - title - optional table title - start - index of first data row to include in output - end - index of last data row to include in output PLUS ONE (list slice style) - fields - names of fields (columns) to include - header - print a header showing field names (True or False) - border - print a border around the table (True or False) - hrules - controls printing of horizontal rules after rows. - Allowed values: ALL, FRAME, HEADER, NONE - vrules - controls printing of vertical rules between columns. - Allowed values: FRAME, ALL, NONE - int_format - controls formatting of integer data - float_format - controls formatting of floating point data - padding_width - number of spaces on either side of column data (only used if - left and right paddings are None) - left_padding_width - number of spaces on left hand side of column data - right_padding_width - number of spaces on right hand side of column data - vertical_char - single character string used to draw vertical lines - horizontal_char - single character string used to draw horizontal lines - junction_char - single character string used to draw line junctions - sortby - name of field to sort rows by - sort_key - sorting key function, applied to data points before sorting - reversesort - True or False to sort in descending or ascending order - print empty - if True, stringify just the header for an empty table, - if False return an empty string """ - - options = self._get_options(kwargs) - - lines = [] - - # Don't think too hard about an empty table - # Is this the desired behaviour? Maybe we should still print the header? - if self.rowcount == 0 and (not options["print_empty"] or not options["border"]): - return "" - - # Get the rows we need to print, taking into account slicing, sorting, etc. 
- rows = self._get_rows(options) - - # Turn all data in all rows into Unicode, formatted as desired - formatted_rows = self._format_rows(rows, options) - - # Compute column widths - self._compute_widths(formatted_rows, options) - self._hrule = self._stringify_hrule(options) - - # Add title - title = options["title"] or self._title - if title: - lines.append(self._stringify_title(title, options)) - - # Add header or top of border - if options["header"]: - lines.append(self._stringify_header(options)) - elif options["border"] and options["hrules"] in (ALL, FRAME): - lines.append(self._hrule) - - # Add rows - for row in formatted_rows: - lines.append(self._stringify_row(row, options)) - - # Add bottom of border - if options["border"] and options["hrules"] == FRAME: - lines.append(self._hrule) - - return self._unicode("\n").join(lines) - - def _stringify_hrule(self, options): - - if not options["border"]: - return "" - lpad, rpad = self._get_padding_widths(options) - if options["vrules"] in (ALL, FRAME): - bits = [options["junction_char"]] - else: - bits = [options["horizontal_char"]] - # For tables with no data or fieldnames - if not self._field_names: - bits.append(options["junction_char"]) - return "".join(bits) - for field, width in zip(self._field_names, self._widths): - if options["fields"] and field not in options["fields"]: - continue - bits.append((width + lpad + rpad) * options["horizontal_char"]) - if options["vrules"] == ALL: - bits.append(options["junction_char"]) - else: - bits.append(options["horizontal_char"]) - if options["vrules"] == FRAME: - bits.pop() - bits.append(options["junction_char"]) - return "".join(bits) - - def _stringify_title(self, title, options): - - lines = [] - lpad, rpad = self._get_padding_widths(options) - if options["border"]: - if options["vrules"] == ALL: - options["vrules"] = FRAME - lines.append(self._stringify_hrule(options)) - options["vrules"] = ALL - elif options["vrules"] == FRAME: - lines.append(self._stringify_hrule(options)) - bits = [] - endpoint = ( - options["vertical_char"] if options["vrules"] in (ALL, FRAME) else " " - ) - bits.append(endpoint) - title = " " * lpad + title + " " * rpad - bits.append(self._justify(title, len(self._hrule) - 2, "c")) - bits.append(endpoint) - lines.append("".join(bits)) - return "\n".join(lines) - - def _stringify_header(self, options): - - bits = [] - lpad, rpad = self._get_padding_widths(options) - if options["border"]: - if options["hrules"] in (ALL, FRAME): - bits.append(self._hrule) - bits.append("\n") - if options["vrules"] in (ALL, FRAME): - bits.append(options["vertical_char"]) - else: - bits.append(" ") - # For tables with no data or field names - if not self._field_names: - if options["vrules"] in (ALL, FRAME): - bits.append(options["vertical_char"]) - else: - bits.append(" ") - for field, width, in zip(self._field_names, self._widths): - if options["fields"] and field not in options["fields"]: - continue - if self._header_style == "cap": - fieldname = field.capitalize() - elif self._header_style == "title": - fieldname = field.title() - elif self._header_style == "upper": - fieldname = field.upper() - elif self._header_style == "lower": - fieldname = field.lower() - else: - fieldname = field - bits.append( - " " * lpad - + self._justify(fieldname, width, self._align[field]) - + " " * rpad - ) - if options["border"]: - if options["vrules"] == ALL: - bits.append(options["vertical_char"]) - else: - bits.append(" ") - # If vrules is FRAME, then we just appended a space at the end - # of the last 
field, when we really want a vertical character - if options["border"] and options["vrules"] == FRAME: - bits.pop() - bits.append(options["vertical_char"]) - if options["border"] and options["hrules"] != NONE: - bits.append("\n") - bits.append(self._hrule) - return "".join(bits) - - def _stringify_row(self, row, options): - - for index, field, value, width, in zip( - range(0, len(row)), self._field_names, row, self._widths - ): - # Enforce max widths - lines = value.split("\n") - new_lines = [] - for line in lines: - if _str_block_width(line) > width: - line = textwrap.fill(line, width) - new_lines.append(line) - lines = new_lines - value = "\n".join(lines) - row[index] = value - - row_height = 0 - for c in row: - h = _get_size(c)[1] - if h > row_height: - row_height = h - - bits = [] - lpad, rpad = self._get_padding_widths(options) - for y in range(0, row_height): - bits.append([]) - if options["border"]: - if options["vrules"] in (ALL, FRAME): - bits[y].append(self.vertical_char) - else: - bits[y].append(" ") - - for field, value, width, in zip(self._field_names, row, self._widths): - - valign = self._valign[field] - lines = value.split("\n") - dHeight = row_height - len(lines) - if dHeight: - if valign == "m": - lines = ( - [""] * int(dHeight / 2) - + lines - + [""] * (dHeight - int(dHeight / 2)) - ) - elif valign == "b": - lines = [""] * dHeight + lines - else: - lines = lines + [""] * dHeight - - y = 0 - for l in lines: - if options["fields"] and field not in options["fields"]: - continue - - bits[y].append( - " " * lpad - + self._justify(l, width, self._align[field]) - + " " * rpad - ) - if options["border"]: - if options["vrules"] == ALL: - bits[y].append(self.vertical_char) - else: - bits[y].append(" ") - y += 1 - - # If vrules is FRAME, then we just appended a space at the end - # of the last field, when we really want a vertical character - for y in range(0, row_height): - if options["border"] and options["vrules"] == FRAME: - bits[y].pop() - bits[y].append(options["vertical_char"]) - - if options["border"] and options["hrules"] == ALL: - bits[row_height - 1].append("\n") - bits[row_height - 1].append(self._hrule) - - for y in range(0, row_height): - bits[y] = "".join(bits[y]) - - return "\n".join(bits) - - def paginate(self, page_length=58, **kwargs): - - pages = [] - kwargs["start"] = kwargs.get("start", 0) - true_end = kwargs.get("end", self.rowcount) - while True: - kwargs["end"] = min(kwargs["start"] + page_length, true_end) - pages.append(self.get_string(**kwargs)) - if kwargs["end"] == true_end: - break - kwargs["start"] += page_length - return "\f".join(pages) - - ############################## - # JSON STRING METHODS # - ############################## - def get_json_string(self, **kwargs): - - """Return string representation of JSON formatted table in the current state - - Arguments: - - none yet""" - - options = self._get_options(kwargs) - - objects = [self.field_names] - for row in self._get_rows(options): - objects.append(dict(zip(self._field_names, row))) - - return json.dumps(objects, indent=4, separators=(",", ": "), sort_keys=True) - - ############################## - # HTML STRING METHODS # - ############################## - - def get_html_string(self, **kwargs): - """Return string representation of HTML formatted version of table in current - state. 
-
-        Arguments:
-
-        title - optional table title
-        start - index of first data row to include in output
-        end - index of last data row to include in output PLUS ONE (list slice style)
-        fields - names of fields (columns) to include
-        header - print a header showing field names (True or False)
-        border - print a border around the table (True or False)
-        hrules - controls printing of horizontal rules after rows.
-            Allowed values: ALL, FRAME, HEADER, NONE
-        vrules - controls printing of vertical rules between columns.
-            Allowed values: FRAME, ALL, NONE
-        int_format - controls formatting of integer data
-        float_format - controls formatting of floating point data
-        padding_width - number of spaces on either side of column data (only used if
-            left and right paddings are None)
-        left_padding_width - number of spaces on left hand side of column data
-        right_padding_width - number of spaces on right hand side of column data
-        sortby - name of field to sort rows by
-        sort_key - sorting key function, applied to data points before sorting
-        attributes - dictionary of name/value pairs to include as HTML attributes in the
-            <table> tag
-        xhtml - print <br/> tags if True, <br> tags if false"""
-
-        options = self._get_options(kwargs)
-
-        if options["format"]:
-            string = self._get_formatted_html_string(options)
-        else:
-            string = self._get_simple_html_string(options)
-
-        return string
-
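# ---------------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of the patch.]
# get_html_string() above dispatches on the `format` option: False yields the
# bare markup from _get_simple_html_string, True the inline-styled markup from
# _get_formatted_html_string. The table contents below are invented:
#     t = PrettyTable(["name", "qty"])
#     t.add_row(["apple", 3])
#     print(t.get_html_string())                            # plain <table>
#     print(t.get_html_string(format=True))                 # styled cells
#     print(t.get_html_string(attributes={"id": "fruit"}))  # attrs on <table>
# ---------------------------------------------------------------------------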
" - else: - linebreak = "
" - - open_tag = ["") - lines.append("".join(open_tag)) - - # Title - title = options["title"] or self._title - if title: - cols = ( - len(options["fields"]) if options["fields"] else len(self.field_names) - ) - lines.append(" ") - lines.append(" " % (cols, title)) - lines.append(" ") - - # Headers - if options["header"]: - lines.append(" ") - for field in self._field_names: - if options["fields"] and field not in options["fields"]: - continue - lines.append( - " " % escape(field).replace("\n", linebreak) - ) - lines.append(" ") - - # Data - rows = self._get_rows(options) - formatted_rows = self._format_rows(rows, options) - for row in formatted_rows: - lines.append(" ") - for field, datum in zip(self._field_names, row): - if options["fields"] and field not in options["fields"]: - continue - lines.append( - " " % escape(datum).replace("\n", linebreak) - ) - lines.append(" ") - - lines.append("
%s
%s
%s
") - - return self._unicode("\n").join(lines) - - def _get_formatted_html_string(self, options): - - lines = [] - lpad, rpad = self._get_padding_widths(options) - if options["xhtml"]: - linebreak = "
" - else: - linebreak = "
" - - open_tag = ["") - lines.append("".join(open_tag)) - - # Title - title = options["title"] or self._title - if title: - cols = ( - len(options["fields"]) if options["fields"] else len(self.field_names) - ) - lines.append(" ") - lines.append(" %s" % (cols, title)) - lines.append(" ") - - # Headers - if options["header"]: - lines.append(" ") - for field in self._field_names: - if options["fields"] and field not in options["fields"]: - continue - lines.append( - ' %s' # noqa: E501 - % (lpad, rpad, escape(field).replace("\n", linebreak)) - ) - lines.append(" ") - - # Data - rows = self._get_rows(options) - formatted_rows = self._format_rows(rows, options) - aligns = [] - valigns = [] - for field in self._field_names: - aligns.append( - {"l": "left", "r": "right", "c": "center"}[self._align[field]] - ) - valigns.append( - {"t": "top", "m": "middle", "b": "bottom"}[self._valign[field]] - ) - for row in formatted_rows: - lines.append(" ") - for field, datum, align, valign in zip( - self._field_names, row, aligns, valigns - ): - if options["fields"] and field not in options["fields"]: - continue - lines.append( - ' %s' # noqa: E501 - % ( - lpad, - rpad, - align, - valign, - escape(datum).replace("\n", linebreak), - ) - ) - lines.append(" ") - lines.append("") - - return self._unicode("\n").join(lines) - - -############################## -# UNICODE WIDTH FUNCTIONS # -############################## - - -def _char_block_width(char): - # Basic Latin, which is probably the most common case - # if char in xrange(0x0021, 0x007e): - # if char >= 0x0021 and char <= 0x007e: - if 0x0021 <= char <= 0x007E: - return 1 - # Chinese, Japanese, Korean (common) - if 0x4E00 <= char <= 0x9FFF: - return 2 - # Hangul - if 0xAC00 <= char <= 0xD7AF: - return 2 - # Combining? - if unicodedata.combining(uni_chr(char)): - return 0 - # Hiragana and Katakana - if 0x3040 <= char <= 0x309F or 0x30A0 <= char <= 0x30FF: - return 2 - # Full-width Latin characters - if 0xFF01 <= char <= 0xFF60: - return 2 - # CJK punctuation - if 0x3000 <= char <= 0x303E: - return 2 - # Backspace and delete - if char in (0x0008, 0x007F): - return -1 - # Other control characters - elif char in (0x0000, 0x000F, 0x001F): - return 0 - # Take a guess - return 1 - - -def _str_block_width(val): - return sum(itermap(_char_block_width, itermap(ord, _re.sub("", val)))) - - -############################## -# TABLE FACTORIES # -############################## - - -def from_csv(fp, field_names=None, **kwargs): - fmtparams = {} - for param in [ - "delimiter", - "doublequote", - "escapechar", - "lineterminator", - "quotechar", - "quoting", - "skipinitialspace", - "strict", - ]: - if param in kwargs: - fmtparams[param] = kwargs.pop(param) - if fmtparams: - reader = csv.reader(fp, **fmtparams) - else: - dialect = csv.Sniffer().sniff(fp.read(1024)) - fp.seek(0) - reader = csv.reader(fp, dialect) - - table = PrettyTable(**kwargs) - if field_names: - table.field_names = field_names - else: - if py3k: - table.field_names = [x.strip() for x in next(reader)] - else: - table.field_names = [x.strip() for x in reader.next()] - - for row in reader: - table.add_row([x.strip() for x in row]) - - return table - - -def from_db_cursor(cursor, **kwargs): - if cursor.description: - table = PrettyTable(**kwargs) - table.field_names = [col[0] for col in cursor.description] - for row in cursor.fetchall(): - table.add_row(row) - return table - - -def from_json(json_string, **kwargs): - table = PrettyTable(**kwargs) - objects = json.loads(json_string) - table.field_names = objects[0] 
-
-
-class TableHandler(HTMLParser):
-    def __init__(self, **kwargs):
-        HTMLParser.__init__(self)
-        self.kwargs = kwargs
-        self.tables = []
-        self.last_row = []
-        self.rows = []
-        self.max_row_width = 0
-        self.active = None
-        self.last_content = ""
-        self.is_last_row_header = False
-        self.colspan = 0
-
-    def handle_starttag(self, tag, attrs):
-        self.active = tag
-        if tag == "th":
-            self.is_last_row_header = True
-        for (key, value) in attrs:
-            if key == "colspan":
-                self.colspan = int(value)
-
-    def handle_endtag(self, tag):
-        if tag in ["th", "td"]:
-            stripped_content = self.last_content.strip()
-            self.last_row.append(stripped_content)
-            if self.colspan:
-                for i in range(1, self.colspan):
-                    self.last_row.append("")
-                self.colspan = 0
-
-        if tag == "tr":
-            self.rows.append((self.last_row, self.is_last_row_header))
-            self.max_row_width = max(self.max_row_width, len(self.last_row))
-            self.last_row = []
-            self.is_last_row_header = False
-        if tag == "table":
-            table = self.generate_table(self.rows)
-            self.tables.append(table)
-            self.rows = []
-        self.last_content = " "
-        self.active = None
-
-    def handle_data(self, data):
-        self.last_content += data
-
-    def generate_table(self, rows):
-        """
-        Generates from a list of rows a PrettyTable object.
-        """
-        table = PrettyTable(**self.kwargs)
-        for row in self.rows:
-            if len(row[0]) < self.max_row_width:
-                appends = self.max_row_width - len(row[0])
-                for i in range(1, appends):
-                    row[0].append("-")
-
-            if row[1]:
-                self.make_fields_unique(row[0])
-                table.field_names = row[0]
-            else:
-                table.add_row(row[0])
-        return table
-
-    def make_fields_unique(self, fields):
-        """
-        iterates over the row and make each field unique
-        """
-        for i in range(0, len(fields)):
-            for j in range(i + 1, len(fields)):
-                if fields[i] == fields[j]:
-                    fields[j] += "'"
-
-
-def from_html(html_code, **kwargs):
-    """
-    Generates a list of PrettyTables from a string of HTML code. Each <table> in
-    the HTML becomes one PrettyTable object.
-    """
-
-    parser = TableHandler(**kwargs)
-    parser.feed(html_code)
-    return parser.tables
-
-
-def from_html_one(html_code, **kwargs):
-    """
-    Generates a PrettyTables from a string of HTML code which contains only a
-    single <table>
-    """
-
-    tables = from_html(html_code, **kwargs)
-    try:
-        assert len(tables) == 1
-    except AssertionError:
-        raise Exception(
-            "More than one <table> in provided HTML code! Use from_html instead."
-        )
-    return tables[0]
-
-
-##############################
-# MAIN (TEST FUNCTION)       #
-##############################
-
-
-def main():
-    print("Generated using setters:")
-    x = PrettyTable(["City name", "Area", "Population", "Annual Rainfall"])
-    x.title = "Australian capital cities"
-    x.sortby = "Population"
-    x.reversesort = True
-    x.int_format["Area"] = "04"
-    x.float_format = "6.1"
-    x.align["City name"] = "l"  # Left align city names
-    x.add_row(["Adelaide", 1295, 1158259, 600.5])
-    x.add_row(["Brisbane", 5905, 1857594, 1146.4])
-    x.add_row(["Darwin", 112, 120900, 1714.7])
-    x.add_row(["Hobart", 1357, 205556, 619.5])
-    x.add_row(["Sydney", 2058, 4336374, 1214.8])
-    x.add_row(["Melbourne", 1566, 3806092, 646.9])
-    x.add_row(["Perth", 5386, 1554769, 869.4])
-    print(x)
-
-    print
-
-    print("Generated using constructor arguments:")
-
-    y = PrettyTable(
-        ["City name", "Area", "Population", "Annual Rainfall"],
-        title="Australian capital cities",
-        sortby="Population",
-        reversesort=True,
-        int_format="04",
-        float_format="6.1",
-        max_width=12,
-        min_width=4,
-        align="c",
-        valign="t",
-    )
-    y.align["City name"] = "l"  # Left align city names
-    y.add_row(["Adelaide", 1295, 1158259, 600.5])
-    y.add_row(["Brisbane", 5905, 1857594, 1146.4])
-    y.add_row(["Darwin", 112, 120900, 1714.7])
-    y.add_row(["Hobart", 1357, 205556, 619.5])
-    y.add_row(["Sydney", 2058, 4336374, 1214.8])
-    y.add_row(["Melbourne", 1566, 3806092, 646.9])
-    y.add_row(["Perth", 5386, 1554769, 869.4])
-    print(y)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/chispa/row_comparer.py b/chispa/row_comparer.py
index 886d2be..7835e5d 100644
--- a/chispa/row_comparer.py
+++ b/chispa/row_comparer.py
@@ -1,7 +1,11 @@
-from pyspark.sql import Row
-from chispa.number_helpers import nan_safe_equality, nan_safe_approx_equality
+from __future__ import annotations
+
 import math
+from pyspark.sql import Row
+
+from chispa.number_helpers import nan_safe_approx_equality, nan_safe_equality
+
 
 def are_rows_equal(r1: Row, r2: Row) -> bool:
     return r1 == r2
@@ -16,7 +20,7 @@ def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool:
     d2 = r2.asDict()
     if allow_nan_equality:
         for key in d1.keys() & d2.keys():
-            if not(nan_safe_equality(d1[key], d2[key])):
+            if not (nan_safe_equality(d1[key], d2[key])):
                 return False
         return True
     else:
@@ -33,13 +37,12 @@ def are_rows_approx_equal(r1: Row, r2: Row, precision: float, allow_nan_equality
     allEqual = True
     for key in d1.keys() & d2.keys():
         if isinstance(d1[key], float) and isinstance(d2[key], float):
-            if allow_nan_equality and not(nan_safe_approx_equality(d1[key], d2[key], precision)):
+            if allow_nan_equality and not (nan_safe_approx_equality(d1[key], d2[key], precision)):
                 allEqual = False
-            elif not(allow_nan_equality) and math.isnan(abs(d1[key] - d2[key])):
+            elif not (allow_nan_equality) and math.isnan(abs(d1[key] - d2[key])):
                 allEqual = False
             elif abs(d1[key] - d2[key]) > precision:
                 allEqual = False
         elif d1[key] != d2[key]:
             allEqual = False
     return allEqual
-
diff --git a/chispa/rows_comparer.py b/chispa/rows_comparer.py
index d1c6e9f..30fdc96 100644
--- a/chispa/rows_comparer.py
+++ b/chispa/rows_comparer.py
@@ -1,14 +1,19 @@
+from __future__ import annotations
+
 from itertools import zip_longest
-from chispa.prettytable import PrettyTable
-from chispa.bcolors import *
+
+from prettytable import PrettyTable
+
 import chispa
-from pyspark.sql.types import Row
-from typing import List
-from chispa.terminal_str_formatter import format_string
-from chispa.default_formats import DefaultFormats +from chispa.formatting import FormattingConfig, format_string -def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=DefaultFormats()): +def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats: FormattingConfig | None = None): + if not formats: + formats = FormattingConfig() + elif not isinstance(formats, FormattingConfig): + formats = FormattingConfig._from_arbitrary_dataclass(formats) + if rows1 != rows2: t = PrettyTable(["df1", "df2"]) zipped = list(zip_longest(rows1, rows2)) @@ -16,10 +21,10 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=Defa for r1, r2 in zipped: if r1 is None and r2 is not None: - t.add_row([None, format_string(r2, formats.mismatched_rows)]) + t.add_row([None, format_string(str(r2), formats.mismatched_rows)]) all_rows_equal = False elif r1 is not None and r2 is None: - t.add_row([format_string(r1, formats.mismatched_rows), None]) + t.add_row([format_string(str(r1), formats.mismatched_rows), None]) all_rows_equal = False else: r_zipped = list(zip_longest(r1.__fields__, r2.__fields__)) @@ -37,11 +42,23 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=Defa r2_res = ", ".join(r2_string) t.add_row([r1_res, r2_res]) - if all_rows_equal == False: + if all_rows_equal is False: raise chispa.DataFramesNotEqualError("\n" + t.get_string()) -def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False, formats=DefaultFormats()): +def assert_generic_rows_equality( + rows1, + rows2, + row_equality_fun, + row_equality_fun_args, + underline_cells=False, + formats: FormattingConfig | None = None, +): + if not formats: + formats = FormattingConfig() + elif not isinstance(formats, FormattingConfig): + formats = FormattingConfig._from_arbitrary_dataclass(formats) + df1_rows = rows1 df2_rows = rows2 zipped = list(zip_longest(df1_rows, df2_rows)) @@ -49,14 +66,20 @@ def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fu all_rows_equal = True for r1, r2 in zipped: # rows are not equal when one is None and the other isn't - if (r1 is not None and r2 is None) or (r2 is not None and r1 is None): + if (r1 is None) ^ (r2 is None): all_rows_equal = False - t.add_row([format_string(r1, formats.mismatched_rows), format_string(r2, formats.mismatched_rows)]) + t.add_row([ + format_string(str(r1), formats.mismatched_rows), + format_string(str(r2), formats.mismatched_rows), + ]) # rows are equal elif row_equality_fun(r1, r2, *row_equality_fun_args): r1_string = ", ".join(map(lambda f: f"{f}={r1[f]}", r1.__fields__)) r2_string = ", ".join(map(lambda f: f"{f}={r2[f]}", r2.__fields__)) - t.add_row([format_string(r1_string, formats.matched_rows), format_string(r2_string, formats.matched_rows)]) + t.add_row([ + format_string(r1_string, formats.matched_rows), + format_string(r2_string, formats.matched_rows), + ]) # otherwise, rows aren't equal else: r_zipped = list(zip_longest(r1.__fields__, r2.__fields__)) @@ -74,5 +97,5 @@ def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fu r2_res = ", ".join(r2_string) t.add_row([r1_res, r2_res]) - if all_rows_equal == False: + if all_rows_equal is False: raise chispa.DataFramesNotEqualError("\n" + t.get_string()) diff --git a/chispa/schema_comparer.py b/chispa/schema_comparer.py index 32d03f7..c2e2150 100644 --- a/chispa/schema_comparer.py +++ b/chispa/schema_comparer.py @@ -1,11 +1,16 @@ -from chispa.prettytable 
import PrettyTable -from chispa.bcolors import * +from __future__ import annotations + from itertools import zip_longest +from prettytable import PrettyTable + +from chispa.bcolors import blue + class SchemasNotEqualError(Exception): - """The schemas are not equal""" - pass + """The schemas are not equal""" + + pass def assert_schema_equality(s1, s2, ignore_nullable=False, ignore_metadata=False): @@ -51,7 +56,6 @@ def assert_basic_schema_equality(s1, s2): raise SchemasNotEqualError("\n" + t.get_string()) - # deprecate this. ignore_nullable should be a flag. def assert_schema_equality_ignore_nullable(s1, s2): if not are_schemas_equal_ignore_nullable(s1, s2): @@ -101,9 +105,9 @@ def are_datatypes_equal_ignore_nullable(dt1, dt2): """ if dt1.typeName() == dt2.typeName(): # Account for array types by inspecting elementType. - if dt1.typeName() == 'array': + if dt1.typeName() == "array": return are_datatypes_equal_ignore_nullable(dt1.elementType, dt2.elementType) - elif dt1.typeName() == 'struct': + elif dt1.typeName() == "struct": return are_schemas_equal_ignore_nullable(dt1, dt2) else: return True diff --git a/chispa/structfield_comparer.py b/chispa/structfield_comparer.py index dc448a1..f1b4782 100644 --- a/chispa/structfield_comparer.py +++ b/chispa/structfield_comparer.py @@ -1 +1,5 @@ -from chispa.schema_comparer import are_structfields_equal \ No newline at end of file +from __future__ import annotations + +from chispa.schema_comparer import are_structfields_equal + +__all__ = ("are_structfields_equal",) diff --git a/chispa/terminal_str_formatter.py b/chispa/terminal_str_formatter.py deleted file mode 100644 index ef5ace6..0000000 --- a/chispa/terminal_str_formatter.py +++ /dev/null @@ -1,30 +0,0 @@ -def format_string(input, formats): - formatting = { - "nc": '\033[0m', # No Color, reset all - "bold": '\033[1m', - "underline": '\033[4m', - "blink": '\033[5m', - "blue": '\033[34m', - "white": '\033[97m', - "red": '\033[31m', - "invert": '\033[7m', - "hide": '\033[8m', - "black": '\033[30m', - "green": '\033[32m', - "yellow": '\033[33m', - "purple": '\033[35m', - "cyan": '\033[36m', - "light_gray": '\033[37m', - "dark_gray": '\033[30m', - "light_red": '\033[31m', - "light_green": '\033[32m', - "light_yellow": '\033[93m', - "light_blue": '\033[34m', - "light_purple": '\033[35m', - "light_cyan": '\033[36m', - } - formatted = input - for format in formats: - s = formatting[format] - formatted = s + str(formatted) + s - return formatting["nc"] + str(formatted) + formatting["nc"] \ No newline at end of file diff --git a/ci/environment-py39.yml b/ci/environment-py39.yml index e16af1d..82d1ed4 100644 --- a/ci/environment-py39.yml +++ b/ci/environment-py39.yml @@ -8,4 +8,3 @@ dependencies: - pytest-describe - pyspark - findspark - diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index edd357e..03becb6 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -3,6 +3,8 @@ https://mkdocstrings.github.io/recipes/#automatic-code-reference-pages """ +from __future__ import annotations + from pathlib import Path import mkdocs_gen_files diff --git a/mkdocs.yml b/mkdocs.yml index d17af8c..260b226 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -22,7 +22,7 @@ plugins: handlers: python: options: - docstring_style: sphinx + docstring_style: google docstring_options: show_if_no_docstring: true show_source: true @@ -69,4 +69,4 @@ markdown_extensions: - pymdownx.arithmatex: generic: true - markdown_include.include: - base_path: . \ No newline at end of file + base_path: . 
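# ---------------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of the patch.]
# After this refactor chispa renders diffs with the PyPI prettytable package
# and a FormattingConfig (replacing the old ad-hoc DefaultFormats dataclass).
# A rough usage sketch mirroring tests/conftest.py below; df1/df2 are assumed
# to be Spark DataFrames:
#     from chispa.formatting import FormattingConfig
#     from chispa.rows_comparer import assert_basic_rows_equality
#
#     formats = FormattingConfig(
#         mismatched_rows={"color": "light_yellow"},
#         matched_rows={"color": "cyan", "style": "bold"},
#     )
#     assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats)
# ---------------------------------------------------------------------------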
diff --git a/poetry.lock b/poetry.lock index fea264b..8612623 100644 --- a/poetry.lock +++ b/poetry.lock @@ -43,6 +43,17 @@ files = [ {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, ] +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.3.2" @@ -167,6 +178,84 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coverage" +version = "7.6.0" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "coverage-7.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dff044f661f59dace805eedb4a7404c573b6ff0cdba4a524141bc63d7be5c7fd"}, + {file = "coverage-7.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8659fd33ee9e6ca03950cfdcdf271d645cf681609153f218826dd9805ab585c"}, + {file = "coverage-7.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7792f0ab20df8071d669d929c75c97fecfa6bcab82c10ee4adb91c7a54055463"}, + {file = "coverage-7.6.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4b3cd1ca7cd73d229487fa5caca9e4bc1f0bca96526b922d61053ea751fe791"}, + {file = "coverage-7.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7e128f85c0b419907d1f38e616c4f1e9f1d1b37a7949f44df9a73d5da5cd53c"}, + {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a94925102c89247530ae1dab7dc02c690942566f22e189cbd53579b0693c0783"}, + {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:dcd070b5b585b50e6617e8972f3fbbee786afca71b1936ac06257f7e178f00f6"}, + {file = "coverage-7.6.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d50a252b23b9b4dfeefc1f663c568a221092cbaded20a05a11665d0dbec9b8fb"}, + {file = "coverage-7.6.0-cp310-cp310-win32.whl", hash = "sha256:0e7b27d04131c46e6894f23a4ae186a6a2207209a05df5b6ad4caee6d54a222c"}, + {file = "coverage-7.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:54dece71673b3187c86226c3ca793c5f891f9fc3d8aa183f2e3653da18566169"}, + {file = "coverage-7.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7b525ab52ce18c57ae232ba6f7010297a87ced82a2383b1afd238849c1ff933"}, + {file = "coverage-7.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bea27c4269234e06f621f3fac3925f56ff34bc14521484b8f66a580aacc2e7d"}, + {file = "coverage-7.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed8d1d1821ba5fc88d4a4f45387b65de52382fa3ef1f0115a4f7a20cdfab0e94"}, + {file = "coverage-7.6.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01c322ef2bbe15057bc4bf132b525b7e3f7206f071799eb8aa6ad1940bcf5fb1"}, + {file = "coverage-7.6.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03cafe82c1b32b770a29fd6de923625ccac3185a54a5e66606da26d105f37dac"}, + {file = 
"coverage-7.6.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0d1b923fc4a40c5832be4f35a5dab0e5ff89cddf83bb4174499e02ea089daf57"}, + {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4b03741e70fb811d1a9a1d75355cf391f274ed85847f4b78e35459899f57af4d"}, + {file = "coverage-7.6.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a73d18625f6a8a1cbb11eadc1d03929f9510f4131879288e3f7922097a429f63"}, + {file = "coverage-7.6.0-cp311-cp311-win32.whl", hash = "sha256:65fa405b837060db569a61ec368b74688f429b32fa47a8929a7a2f9b47183713"}, + {file = "coverage-7.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:6379688fb4cfa921ae349c76eb1a9ab26b65f32b03d46bb0eed841fd4cb6afb1"}, + {file = "coverage-7.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f7db0b6ae1f96ae41afe626095149ecd1b212b424626175a6633c2999eaad45b"}, + {file = "coverage-7.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bbdf9a72403110a3bdae77948b8011f644571311c2fb35ee15f0f10a8fc082e8"}, + {file = "coverage-7.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cc44bf0315268e253bf563f3560e6c004efe38f76db03a1558274a6e04bf5d5"}, + {file = "coverage-7.6.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:da8549d17489cd52f85a9829d0e1d91059359b3c54a26f28bec2c5d369524807"}, + {file = "coverage-7.6.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0086cd4fc71b7d485ac93ca4239c8f75732c2ae3ba83f6be1c9be59d9e2c6382"}, + {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1fad32ee9b27350687035cb5fdf9145bc9cf0a094a9577d43e909948ebcfa27b"}, + {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:044a0985a4f25b335882b0966625270a8d9db3d3409ddc49a4eb00b0ef5e8cee"}, + {file = "coverage-7.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:76d5f82213aa78098b9b964ea89de4617e70e0d43e97900c2778a50856dac605"}, + {file = "coverage-7.6.0-cp312-cp312-win32.whl", hash = "sha256:3c59105f8d58ce500f348c5b56163a4113a440dad6daa2294b5052a10db866da"}, + {file = "coverage-7.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:ca5d79cfdae420a1d52bf177de4bc2289c321d6c961ae321503b2ca59c17ae67"}, + {file = "coverage-7.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d39bd10f0ae453554798b125d2f39884290c480f56e8a02ba7a6ed552005243b"}, + {file = "coverage-7.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:beb08e8508e53a568811016e59f3234d29c2583f6b6e28572f0954a6b4f7e03d"}, + {file = "coverage-7.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2e16f4cd2bc4d88ba30ca2d3bbf2f21f00f382cf4e1ce3b1ddc96c634bc48ca"}, + {file = "coverage-7.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6616d1c9bf1e3faea78711ee42a8b972367d82ceae233ec0ac61cc7fec09fa6b"}, + {file = "coverage-7.6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad4567d6c334c46046d1c4c20024de2a1c3abc626817ae21ae3da600f5779b44"}, + {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d17c6a415d68cfe1091d3296ba5749d3d8696e42c37fca5d4860c5bf7b729f03"}, + {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9146579352d7b5f6412735d0f203bbd8d00113a680b66565e205bc605ef81bc6"}, + {file = "coverage-7.6.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = 
"sha256:cdab02a0a941af190df8782aafc591ef3ad08824f97850b015c8c6a8b3877b0b"}, + {file = "coverage-7.6.0-cp38-cp38-win32.whl", hash = "sha256:df423f351b162a702c053d5dddc0fc0ef9a9e27ea3f449781ace5f906b664428"}, + {file = "coverage-7.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:f2501d60d7497fd55e391f423f965bbe9e650e9ffc3c627d5f0ac516026000b8"}, + {file = "coverage-7.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7221f9ac9dad9492cecab6f676b3eaf9185141539d5c9689d13fd6b0d7de840c"}, + {file = "coverage-7.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ddaaa91bfc4477d2871442bbf30a125e8fe6b05da8a0015507bfbf4718228ab2"}, + {file = "coverage-7.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4cbe651f3904e28f3a55d6f371203049034b4ddbce65a54527a3f189ca3b390"}, + {file = "coverage-7.6.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:831b476d79408ab6ccfadaaf199906c833f02fdb32c9ab907b1d4aa0713cfa3b"}, + {file = "coverage-7.6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c3d091059ad0b9c59d1034de74a7f36dcfa7f6d3bde782c49deb42438f2450"}, + {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4d5fae0a22dc86259dee66f2cc6c1d3e490c4a1214d7daa2a93d07491c5c04b6"}, + {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:07ed352205574aad067482e53dd606926afebcb5590653121063fbf4e2175166"}, + {file = "coverage-7.6.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:49c76cdfa13015c4560702574bad67f0e15ca5a2872c6a125f6327ead2b731dd"}, + {file = "coverage-7.6.0-cp39-cp39-win32.whl", hash = "sha256:482855914928c8175735a2a59c8dc5806cf7d8f032e4820d52e845d1f731dca2"}, + {file = "coverage-7.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:543ef9179bc55edfd895154a51792b01c017c87af0ebaae092720152e19e42ca"}, + {file = "coverage-7.6.0-pp38.pp39.pp310-none-any.whl", hash = "sha256:6fe885135c8a479d3e37a7aae61cbd3a0fb2deccb4dda3c25f92a49189f766d6"}, + {file = "coverage-7.6.0.tar.gz", hash = "sha256:289cc803fa1dc901f84701ac10c9ee873619320f2f9aff38794db4a4a0268d51"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + +[[package]] +name = "distlib" +version = "0.3.8" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, + {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -181,6 +270,22 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "filelock" +version = "3.15.4" +description = "A platform independent file lock." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] +typing = ["typing-extensions (>=4.8)"] + [[package]] name = "findspark" version = "1.4.2" @@ -224,6 +329,20 @@ files = [ astunparse = {version = ">=1.6", markers = "python_version < \"3.9\""} colorama = ">=0.4" +[[package]] +name = "identify" +version = "2.6.0" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.6.0-py2.py3-none-any.whl", hash = "sha256:e79ae4406387a9d300332b5fd366d8994f1525e8414984e1a59e058b2eda2dd0"}, + {file = "identify-2.6.0.tar.gz", hash = "sha256:cb171c685bdc31bcc4c1734698736a7d5b6c8bf2e0c15117f4d469c8640ae5cf"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.7" @@ -586,6 +705,17 @@ files = [ griffe = ">=0.47" mkdocstrings = ">=0.25" +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "packaging" version = "24.1" @@ -649,6 +779,41 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "3.3.3" +description = "A framework for managing and maintaining multi-language pre-commit hooks." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "pre_commit-3.3.3-py2.py3-none-any.whl", hash = "sha256:10badb65d6a38caff29703362271d7dca483d01da88f9d7e05d0b97171c136cb"}, + {file = "pre_commit-3.3.3.tar.gz", hash = "sha256:a2256f489cd913d575c145132ae196fe335da32d91a8294b7afe6622335dd023"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + +[[package]] +name = "prettytable" +version = "3.10.2" +description = "A simple Python library for easily displaying tabular data in a visually appealing ASCII table format" +optional = false +python-versions = ">=3.8" +files = [ + {file = "prettytable-3.10.2-py3-none-any.whl", hash = "sha256:1cbfdeb4bcc73976a778a0fb33cb6d752e75396f16574dcb3e2d6332fd93c76a"}, + {file = "prettytable-3.10.2.tar.gz", hash = "sha256:29ec6c34260191d42cd4928c28d56adec360ac2b1208a26c7e4f14b90cc8bc84"}, +] + +[package.dependencies] +wcwidth = "*" + +[package.extras] +tests = ["pytest", "pytest-cov", "pytest-lazy-fixtures"] + [[package]] name = "py4j" version = "0.10.9.7" @@ -734,6 +899,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-cov" +version = "5.0.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, + {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] + [[package]] name = "pytest-describe" version = "2.2.0" @@ -1006,6 +1189,26 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.26.3" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, + {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "watchdog" version = "4.0.1" @@ -1050,6 +1253,17 @@ files = [ [package.extras] watchmedo = ["PyYAML (>=3.10)"] +[[package]] +name = "wcwidth" +version = "0.2.13" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = 
"sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, + {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, +] + [[package]] name = "wheel" version = "0.43.0" @@ -1082,4 +1296,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "28562e740b45e6351b1c6820acaeeb6ba6ba102bc9897e7b917c6f7624d4975d" +content-hash = "9fde9a932fca40538936262263439debec9162feea75797fc973fe6a92a770b1" diff --git a/pyproject.toml b/pyproject.toml index b288aff..167bfda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,12 +33,15 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.8,<4.0" +prettytable = "^3.10.2" [tool.poetry.group.dev.dependencies] pytest = "7.4.2" pyspark = ">3.0.0" findspark = "1.4.2" pytest-describe = "^2.1.0" +pytest-cov = "^5.0.0" +pre-commit = "3.3.3" [tool.poetry.group.mkdocs.dependencies] mkdocs = "^1.6.0" @@ -55,3 +58,27 @@ optional = true [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.ruff] +target-version = "py39" +line-length = 120 +fix = true + +[tool.ruff.format] +preview = true + +[tool.ruff.lint] +select = ["E", "F", "I", "RUF", "UP"] +ignore = [ + # Line too long + "E501" +] + +[tool.ruff.lint.flake8-type-checking] +strict = true + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["S101", "S603"] + +[tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] diff --git a/tests/conftest.py b/tests/conftest.py index 3062077..5875d42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,25 @@ +from __future__ import annotations + import pytest -from dataclasses import dataclass -from chispa import Chispa -@dataclass -class MyFormats: - mismatched_rows = ["light_yellow"] - matched_rows = ["cyan", "bold"] - mismatched_cells = ["purple"] - matched_cells = ["blue"] +from chispa.formatting import FormattingConfig + @pytest.fixture() def my_formats(): - return MyFormats() + return FormattingConfig( + mismatched_rows={"color": "light_yellow"}, + matched_rows={"color": "cyan", "style": "bold"}, + mismatched_cells={"color": "purple"}, + matched_cells={"color": "blue"}, + ) @pytest.fixture() def my_chispa(): - return Chispa(formats=MyFormats()) + return FormattingConfig( + mismatched_rows={"color": "light_yellow"}, + matched_rows={"color": "cyan", "style": "bold"}, + mismatched_cells={"color": "purple"}, + matched_cells={"color": "blue"}, + ) diff --git a/tests/formatting/test_formats.py b/tests/formatting/test_formats.py new file mode 100644 index 0000000..99f0d3b --- /dev/null +++ b/tests/formatting/test_formats.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import re + +import pytest + +from chispa.formatting import Color, Format, Style + + +def test_format_from_dict_valid(): + format_dict = {"color": "blue", "style": ["bold", "underline"]} + format_instance = Format.from_dict(format_dict) + assert format_instance.color == Color.BLUE + assert format_instance.style == [Style.BOLD, Style.UNDERLINE] + + +def test_format_from_dict_invalid_color(): + format_dict = {"color": "invalid_color", "style": ["bold"]} + with pytest.raises(ValueError) as exc_info: + Format.from_dict(format_dict) + assert str(exc_info.value) == ( + "Invalid color name: invalid_color. 
Valid color names are " "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " "'light_cyan', 'white']"
+    )
+
+
+def test_format_from_dict_invalid_style():
+    format_dict = {"color": "blue", "style": ["invalid_style"]}
+    with pytest.raises(ValueError) as exc_info:
+        Format.from_dict(format_dict)
+    assert str(exc_info.value) == (
+        "Invalid style name: invalid_style. Valid style names are " "['bold', 'underline', 'blink', 'invert', 'hide']"
+    )
+
+
+def test_format_from_dict_invalid_key():
+    format_dict = {"invalid_key": "value"}
+    # pytest.raises makes the test fail if no ValueError is raised at all
+    with pytest.raises(ValueError) as exc_info:
+        Format.from_dict(format_dict)
+    error_message = str(exc_info.value)
+    assert re.match(
+        r"Invalid keys in format dictionary: \{'invalid_key'\}. Valid keys are \{('color', 'style'|'style', 'color')\}",
+        error_message,
+    )
+
+
+def test_format_from_list_valid():
+    values = ["blue", "bold", "underline"]
+    format_instance = Format.from_list(values)
+    assert format_instance.color == Color.BLUE
+    assert format_instance.style == [Style.BOLD, Style.UNDERLINE]
+
+
+def test_format_from_list_invalid_color():
+    values = ["invalid_color", "bold"]
+    with pytest.raises(ValueError) as exc_info:
+        Format.from_list(values)
+    assert str(exc_info.value) == (
+        "Invalid value: invalid_color. Valid values are colors: "
+        "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', "
+        "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', "
+        "'light_cyan', 'white'] and styles: ['bold', 'underline', 'blink', 'invert', 'hide']"
+    )
+
+
+def test_format_from_list_invalid_style():
+    values = ["blue", "invalid_style"]
+    with pytest.raises(ValueError) as exc_info:
+        Format.from_list(values)
+    assert str(exc_info.value) == (
+        "Invalid value: invalid_style. 
Valid values are colors: " + "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " + "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " + "'light_cyan', 'white'] and styles: ['bold', 'underline', 'blink', 'invert', 'hide']" + ) + + +def test_format_from_list_non_string_elements(): + values = ["blue", 123] + with pytest.raises(ValueError) as exc_info: + Format.from_list(values) + assert str(exc_info.value) == "All elements in the list must be strings" + + +def test_format_from_dict_empty(): + format_dict = {} + format_instance = Format.from_dict(format_dict) + assert format_instance.color is None + assert format_instance.style is None + + +def test_format_from_list_empty(): + values = [] + format_instance = Format.from_list(values) + assert format_instance.color is None + assert format_instance.style is None diff --git a/tests/formatting/test_formatting_config.py b/tests/formatting/test_formatting_config.py new file mode 100644 index 0000000..3214ac8 --- /dev/null +++ b/tests/formatting/test_formatting_config.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import re + +import pytest + +from chispa.formatting import Color, FormattingConfig, Style + + +def test_default_mismatched_rows(): + config = FormattingConfig() + assert config.mismatched_rows.color == Color.RED + assert config.mismatched_rows.style is None + + +def test_default_matched_rows(): + config = FormattingConfig() + assert config.matched_rows.color == Color.BLUE + assert config.matched_rows.style is None + + +def test_default_mismatched_cells(): + config = FormattingConfig() + assert config.mismatched_cells.color == Color.RED + assert config.mismatched_cells.style == [Style.UNDERLINE] + + +def test_default_matched_cells(): + config = FormattingConfig() + assert config.matched_cells.color == Color.BLUE + assert config.matched_cells.style is None + + +def test_custom_mismatched_rows(): + config = FormattingConfig(mismatched_rows={"color": "green", "style": ["bold", "underline"]}) + assert config.mismatched_rows.color == Color.GREEN + assert config.mismatched_rows.style == [Style.BOLD, Style.UNDERLINE] + + +def test_custom_matched_rows(): + config = FormattingConfig(matched_rows={"color": "yellow"}) + assert config.matched_rows.color == Color.YELLOW + assert config.matched_rows.style is None + + +def test_custom_mismatched_cells(): + config = FormattingConfig(mismatched_cells={"color": "purple", "style": ["blink"]}) + assert config.mismatched_cells.color == Color.PURPLE + assert config.mismatched_cells.style == [Style.BLINK] + + +def test_custom_matched_cells(): + config = FormattingConfig(matched_cells={"color": "cyan", "style": ["invert", "hide"]}) + assert config.matched_cells.color == Color.CYAN + assert config.matched_cells.style == [Style.INVERT, Style.HIDE] + + +def test_invalid_color(): + with pytest.raises(ValueError) as exc_info: + FormattingConfig(mismatched_rows={"color": "invalid_color"}) + assert str(exc_info.value) == ( + "Invalid color name: invalid_color. Valid color names are " + "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " + "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " + "'light_cyan', 'white']" + ) + + +def test_invalid_style(): + with pytest.raises(ValueError) as exc_info: + FormattingConfig(mismatched_rows={"style": ["invalid_style"]}) + assert str(exc_info.value) == ( + "Invalid style name: invalid_style. 
Valid style names are " "['bold', 'underline', 'blink', 'invert', 'hide']"
+    )
+
+
+def test_invalid_key():
+    # pytest.raises makes the test fail if no ValueError is raised at all
+    with pytest.raises(ValueError) as exc_info:
+        FormattingConfig(mismatched_rows={"invalid_key": "value"})
+    error_message = str(exc_info.value)
+    assert re.match(
+        r"Invalid keys in format dictionary: \{'invalid_key'\}. Valid keys are \{('color', 'style'|'style', 'color')\}",
+        error_message,
+    )
diff --git a/tests/formatting/test_terminal_string_formatter.py b/tests/formatting/test_terminal_string_formatter.py
new file mode 100644
index 0000000..d6d4959
--- /dev/null
+++ b/tests/formatting/test_terminal_string_formatter.py
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from chispa.formatting import RESET, format_string
+from chispa.formatting.formats import Color, Format, Style
+
+
+def test_format_with_enum_inputs():
+    format = Format(color=Color.BLUE, style=[Style.BOLD, Style.UNDERLINE])
+    formatted_string = format_string("Hello, World!", format)
+    expected_string = f"{Style.BOLD.value}{Style.UNDERLINE.value}{Color.BLUE.value}Hello, World!{RESET}"
+    assert formatted_string == expected_string
+
+
+def test_format_with_no_style():
+    format = Format(color=Color.GREEN, style=[])
+    formatted_string = format_string("Hello, World!", format)
+    expected_string = f"{Color.GREEN.value}Hello, World!{RESET}"
+    assert formatted_string == expected_string
+
+
+def test_format_with_no_color():
+    format = Format(color=None, style=[Style.BLINK])
+    formatted_string = format_string("Hello, World!", format)
+    expected_string = f"{Style.BLINK.value}Hello, World!{RESET}"
+    assert formatted_string == expected_string
+
+
+def test_format_with_no_color_or_style():
+    format = Format(color=None, style=[])
+    formatted_string = format_string("Hello, World!", format)
+    expected_string = "Hello, World!"
+ assert formatted_string == expected_string diff --git a/tests/spark.py b/tests/spark.py index 475955f..b5df699 100644 --- a/tests/spark.py +++ b/tests/spark.py @@ -1,7 +1,5 @@ -from pyspark.sql import SparkSession +from __future__ import annotations -spark = SparkSession.builder \ - .master("local") \ - .appName("chispa") \ - .getOrCreate() +from pyspark.sql import SparkSession +spark = SparkSession.builder.master("local").appName("chispa").getOrCreate() diff --git a/tests/test_column_comparer.py b/tests/test_column_comparer.py index cf32142..1101545 100644 --- a/tests/test_column_comparer.py +++ b/tests/test_column_comparer.py @@ -1,14 +1,17 @@ +from __future__ import annotations + import pytest +from chispa import ColumnsNotEqualError, assert_approx_column_equality, assert_column_equality + from .spark import spark -from chispa import * def describe_assert_column_equality(): def it_throws_error_with_data_mismatch(): data = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] df = spark.createDataFrame(data, ["name", "expected_name"]) - with pytest.raises(ColumnsNotEqualError) as e_info: + with pytest.raises(ColumnsNotEqualError): assert_column_equality(df, "name", "expected_name") def it_doesnt_throw_without_mismatch(): @@ -16,7 +19,6 @@ def it_doesnt_throw_without_mismatch(): df = spark.createDataFrame(data, ["name", "expected_name"]) assert_column_equality(df, "name", "expected_name") - def it_works_with_integer_values(): data = [(1, 1), (10, 10), (8, 8), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) @@ -25,28 +27,24 @@ def it_works_with_integer_values(): def describe_assert_approx_column_equality(): def it_works_with_no_mismatches(): - data = [(1.1, 1.1), (1.0004, 1.0005), (.4, .45), (None, None)] + data = [(1.1, 1.1), (1.0004, 1.0005), (0.4, 0.45), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) assert_approx_column_equality(df, "num1", "num2", 0.1) - def it_throws_when_difference_is_bigger_than_precision(): - data = [(1.5, 1.1), (1.0004, 1.0005), (.4, .45)] + data = [(1.5, 1.1), (1.0004, 1.0005), (0.4, 0.45)] df = spark.createDataFrame(data, ["num1", "num2"]) - with pytest.raises(ColumnsNotEqualError) as e_info: + with pytest.raises(ColumnsNotEqualError): assert_approx_column_equality(df, "num1", "num2", 0.1) - def it_throws_when_comparing_floats_with_none(): data = [(1.1, 1.1), (2.2, 2.2), (3.3, None)] df = spark.createDataFrame(data, ["num1", "num2"]) - with pytest.raises(ColumnsNotEqualError) as e_info: + with pytest.raises(ColumnsNotEqualError): assert_approx_column_equality(df, "num1", "num2", 0.1) - def it_throws_when_comparing_none_with_floats(): data = [(1.1, 1.1), (2.2, 2.2), (None, 3.3)] df = spark.createDataFrame(data, ["num1", "num2"]) - with pytest.raises(ColumnsNotEqualError) as e_info: + with pytest.raises(ColumnsNotEqualError): assert_approx_column_equality(df, "num1", "num2", 0.1) - diff --git a/tests/test_dataframe_comparer.py b/tests/test_dataframe_comparer.py index bb90765..9f14c20 100644 --- a/tests/test_dataframe_comparer.py +++ b/tests/test_dataframe_comparer.py @@ -1,11 +1,15 @@ +from __future__ import annotations + +import math + import pytest +from pyspark.sql.types import IntegerType, StringType, StructField, StructType -from .spark import * -from chispa import * +from chispa import DataFramesNotEqualError, assert_approx_df_equality, assert_df_equality from chispa.dataframe_comparer import are_dfs_equal from chispa.schema_comparer import SchemasNotEqualError -import math -from pyspark.sql.types import StringType, 
IntegerType, StructType, StructField + +from .spark import spark def describe_assert_df_equality(): @@ -14,10 +18,9 @@ def it_throws_with_schema_mismatches(): df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - with pytest.raises(SchemasNotEqualError) as e_info: + with pytest.raises(SchemasNotEqualError): assert_df_equality(df1, df2) - def it_can_work_with_different_row_orders(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -25,7 +28,6 @@ def it_can_work_with_different_row_orders(): df2 = spark.createDataFrame(data2, ["num", "name"]) assert_df_equality(df1, df2, transforms=[lambda df: df.sort(df.columns)]) - def it_can_work_with_different_row_orders_with_a_flag(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -33,7 +35,6 @@ def it_can_work_with_different_row_orders_with_a_flag(): df2 = spark.createDataFrame(data2, ["num", "name"]) assert_df_equality(df1, df2, ignore_row_order=True) - def it_can_work_with_different_row_and_column_orders(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -41,7 +42,6 @@ def it_can_work_with_different_row_and_column_orders(): df2 = spark.createDataFrame(data2, ["name", "num"]) assert_df_equality(df1, df2, ignore_row_order=True, ignore_column_order=True) - def it_raises_for_row_insensitive_with_diff_content(): data1 = [(1, "XXXX"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -50,133 +50,111 @@ def it_raises_for_row_insensitive_with_diff_content(): with pytest.raises(DataFramesNotEqualError): assert_df_equality(df1, df2, transforms=[lambda df: df.sort(df.columns)]) - def it_throws_with_schema_column_order_mismatch(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) data2 = [("jose", 1), ("li", 1)] df2 = spark.createDataFrame(data2, ["name", "num"]) - with pytest.raises(SchemasNotEqualError) as e_info: + with pytest.raises(SchemasNotEqualError): assert_df_equality(df1, df2) - def it_does_not_throw_on_schema_column_order_mismatch_with_transforms(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) data2 = [("jose", 1), ("li", 2)] df2 = spark.createDataFrame(data2, ["name", "num"]) - assert_df_equality(df1, df2, transforms=[ - lambda df: df.select(sorted(df.columns)) - ]) - + assert_df_equality(df1, df2, transforms=[lambda df: df.select(sorted(df.columns))]) def it_throws_with_schema_mismatch(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "different_name"]) data2 = [("jose", 1), ("li", 2)] df2 = spark.createDataFrame(data2, ["name", "num"]) - with pytest.raises(SchemasNotEqualError) as e_info: + with pytest.raises(SchemasNotEqualError): assert_df_equality(df1, df2, transforms=[lambda df: df.select(sorted(df.columns))]) - def it_throws_with_content_mismatches(): data1 = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] df1 = spark.createDataFrame(data1, ["name", "expected_name"]) data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_df_equality(df1, df2) - def it_throws_with_length_mismatches(): data1 = [("jose", "jose"), ("li", "li"), ("laura", "laura")] df1 = spark.createDataFrame(data1, ["name", 
"expected_name"]) data2 = [("jose", "jose"), ("li", "li")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_df_equality(df1, df2) - def it_can_consider_nan_values_equal(): - data1 = [(float('nan'), "jose"), (2.0, "li")] + data1 = [(float("nan"), "jose"), (2.0, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) - data2 = [(float('nan'), "jose"), (2.0, "li")] + data2 = [(float("nan"), "jose"), (2.0, "li")] df2 = spark.createDataFrame(data2, ["num", "name"]) assert_df_equality(df1, df2, allow_nan_equality=True) - def it_does_not_consider_nan_values_equal_by_default(): - data1 = [(float('nan'), "jose"), (2.0, "li")] + data1 = [(float("nan"), "jose"), (2.0, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) - data2 = [(float('nan'), "jose"), (2.0, "li")] + data2 = [(float("nan"), "jose"), (2.0, "li")] df2 = spark.createDataFrame(data2, ["num", "name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_df_equality(df1, df2, allow_nan_equality=False) - def it_can_ignore_metadata(): rows_data = [("jose", 1), ("li", 2), ("luisa", 3)] - schema1 = StructType( - [ - StructField("name", StringType(), True, {"hi": "no"}), - StructField("age", IntegerType(), True), - ] - ) - schema2 = StructType( - [ - StructField("name", StringType(), True, {"hi": "whatever"}), - StructField("age", IntegerType(), True), - ] - ) + schema1 = StructType([ + StructField("name", StringType(), True, {"hi": "no"}), + StructField("age", IntegerType(), True), + ]) + schema2 = StructType([ + StructField("name", StringType(), True, {"hi": "whatever"}), + StructField("age", IntegerType(), True), + ]) df1 = spark.createDataFrame(rows_data, schema1) df2 = spark.createDataFrame(rows_data, schema2) assert_df_equality(df1, df2, ignore_metadata=True) - def it_catches_mismatched_metadata(): rows_data = [("jose", 1), ("li", 2), ("luisa", 3)] - schema1 = StructType( - [ - StructField("name", StringType(), True, {"hi": "no"}), - StructField("age", IntegerType(), True), - ] - ) - schema2 = StructType( - [ - StructField("name", StringType(), True, {"hi": "whatever"}), - StructField("age", IntegerType(), True), - ] - ) + schema1 = StructType([ + StructField("name", StringType(), True, {"hi": "no"}), + StructField("age", IntegerType(), True), + ]) + schema2 = StructType([ + StructField("name", StringType(), True, {"hi": "whatever"}), + StructField("age", IntegerType(), True), + ]) df1 = spark.createDataFrame(rows_data, schema1) df2 = spark.createDataFrame(rows_data, schema2) - with pytest.raises(SchemasNotEqualError) as e_info: + with pytest.raises(SchemasNotEqualError): assert_df_equality(df1, df2) - def describe_are_dfs_equal(): def it_returns_false_with_schema_mismatches(): data1 = [(1, "jose"), (2, "li"), (3, "laura")] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - assert are_dfs_equal(df1, df2) == False - + assert are_dfs_equal(df1, df2) is False def it_returns_false_with_content_mismatches(): data1 = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] df1 = spark.createDataFrame(data1, ["name", "expected_name"]) data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - assert are_dfs_equal(df1, df2) == False - + assert are_dfs_equal(df1, 
df2) is False def it_returns_true_when_dfs_are_same(): data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df1 = spark.createDataFrame(data1, ["name", "expected_name"]) data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - assert are_dfs_equal(df1, df2) == True + assert are_dfs_equal(df1, df2) is True def describe_assert_approx_df_equality(): @@ -185,20 +163,17 @@ def it_throws_with_content_mismatch(): df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [(1.0, "jose"), (1.05, "li"), (1.0, "laura"), (None, "hi")] df2 = spark.createDataFrame(data2, ["num", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_approx_df_equality(df1, df2, 0.1) - def it_throws_with_with_length_mismatch(): data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [(1.0, "jose"), (1.05, "li")] df2 = spark.createDataFrame(data2, ["num", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_approx_df_equality(df1, df2, 0.1) - - def it_does_not_throw_with_no_mismatch(): data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) @@ -206,7 +181,6 @@ def it_does_not_throw_with_no_mismatch(): df2 = spark.createDataFrame(data2, ["num", "expected_name"]) assert_approx_df_equality(df1, df2, 0.1) - def it_does_not_throw_with_different_row_col_order(): data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) @@ -214,10 +188,21 @@ def it_does_not_throw_with_different_row_col_order(): df2 = spark.createDataFrame(data2, ["expected_name", "num"]) assert_approx_df_equality(df1, df2, 0.1, ignore_row_order=True, ignore_column_order=True) - def it_does_not_throw_with_nan_values(): - data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None), (float("nan"), "buk")] + data1 = [ + (1.0, "jose"), + (1.1, "li"), + (1.2, "laura"), + (None, None), + (float("nan"), "buk"), + ] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) - data2 = [(1.0, "jose"), (1.05, "li"), (1.2, "laura"), (None, None), (math.nan, "buk")] + data2 = [ + (1.0, "jose"), + (1.05, "li"), + (1.2, "laura"), + (None, None), + (math.nan, "buk"), + ] df2 = spark.createDataFrame(data2, ["num", "expected_name"]) assert_approx_df_equality(df1, df2, 0.1, allow_nan_equality=True) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py new file mode 100644 index 0000000..a38b83d --- /dev/null +++ b/tests/test_deprecated.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import warnings +from dataclasses import dataclass + +import pytest + +from chispa import DataFramesNotEqualError, assert_basic_rows_equality +from chispa.default_formats import DefaultFormats +from chispa.formatting import FormattingConfig + +from .spark import spark + + +def test_default_formats_deprecation_warning(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + DefaultFormats() + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "DefaultFormats is deprecated" in str(w[-1].message) + + +def test_that_default_formats_still_works(): + data1 = [(1, "jose"), (2, "li"), (3, "laura")] + df1 = spark.createDataFrame(data1, ["num", "expected_name"]) + data2 = 
[("bob", "jose"), ("li", "li"), ("luisa", "laura")] + df2 = spark.createDataFrame(data2, ["name", "expected_name"]) + with pytest.raises(DataFramesNotEqualError): + assert_basic_rows_equality(df1.collect(), df2.collect(), formats=DefaultFormats()) + + +def test_deprecated_arbitrary_dataclass(): + data1 = [(1, "jose"), (2, "li"), (3, "laura")] + df1 = spark.createDataFrame(data1, ["num", "expected_name"]) + data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] + df2 = spark.createDataFrame(data2, ["name", "expected_name"]) + + @dataclass + class CustomFormats: + mismatched_rows = ["green"] # noqa: RUF012 + matched_rows = ["yellow"] # noqa: RUF012 + mismatched_cells = ["purple", "bold"] # noqa: RUF012 + matched_cells = ["cyan"] # noqa: RUF012 + + with warnings.catch_warnings(record=True) as w: + try: + assert_basic_rows_equality(df1.collect(), df2.collect(), formats=CustomFormats()) + # should not reach the line below due to the raised error. + # pytest.raises does not work as expected since then we cannot verify the warning. + assert False + except DataFramesNotEqualError: + warnings.simplefilter("always") + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + assert "Using an arbitrary dataclass is deprecated." in str(w[-1].message) + + +def test_invalid_value_in_default_formats(): + @dataclass + class InvalidFormats: + mismatched_rows = ["green"] # noqa: RUF012 + matched_rows = ["yellow"] # noqa: RUF012 + mismatched_cells = ["purple", "invalid"] # noqa: RUF012 + matched_cells = ["cyan"] # noqa: RUF012 + + with pytest.raises(ValueError): + FormattingConfig._from_arbitrary_dataclass(InvalidFormats()) diff --git a/tests/test_readme_examples.py b/tests/test_readme_examples.py index a597ffd..fe28ee0 100644 --- a/tests/test_readme_examples.py +++ b/tests/test_readme_examples.py @@ -1,21 +1,27 @@ -import pytest +from __future__ import annotations -from chispa import * import pyspark.sql.functions as F +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType, StructField, StructType + +from chispa import ( + ColumnsNotEqualError, + DataFramesNotEqualError, + assert_approx_column_equality, + assert_approx_df_equality, + assert_basic_rows_equality, + assert_column_equality, + assert_df_equality, +) from chispa.schema_comparer import SchemasNotEqualError -from pyspark.sql.types import * def remove_non_word_characters(col): return F.regexp_replace(col, "[^\\w\\s]+", "") -from pyspark.sql import SparkSession - -spark = (SparkSession.builder - .master("local") - .appName("chispa") - .getOrCreate()) +spark = SparkSession.builder.master("local").appName("chispa").getOrCreate() def describe_column_equality(): @@ -24,116 +30,98 @@ def test_removes_non_word_characters_short(): ("jo&&se", "jose"), ("**li**", "li"), ("#::luisa", "luisa"), - (None, None) + (None, None), ] - df = spark.createDataFrame(data, ["name", "expected_name"])\ - .withColumn("clean_name", remove_non_word_characters(F.col("name"))) + df = spark.createDataFrame(data, ["name", "expected_name"]).withColumn( + "clean_name", remove_non_word_characters(F.col("name")) + ) assert_column_equality(df, "clean_name", "expected_name") - def test_remove_non_word_characters_nice_error(): data = [ ("matt7", "matt"), ("bill&", "bill"), ("isabela*", "isabela"), - (None, None) + (None, None), ] - df = spark.createDataFrame(data, ["name", "expected_name"])\ - .withColumn("clean_name", remove_non_word_characters(F.col("name"))) + df = 
spark.createDataFrame(data, ["name", "expected_name"]).withColumn( + "clean_name", remove_non_word_characters(F.col("name")) + ) # assert_column_equality(df, "clean_name", "expected_name") - with pytest.raises(ColumnsNotEqualError) as e_info: + with pytest.raises(ColumnsNotEqualError): assert_column_equality(df, "clean_name", "expected_name") - def describe_dataframe_equality(): def test_remove_non_word_characters_long(): - source_data = [ - ("jo&&se",), - ("**li**",), - ("#::luisa",), - (None,) - ] + source_data = [("jo&&se",), ("**li**",), ("#::luisa",), (None,)] source_df = spark.createDataFrame(source_data, ["name"]) - actual_df = source_df.withColumn( - "clean_name", - remove_non_word_characters(F.col("name")) - ) + actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) expected_data = [ ("jo&&se", "jose"), ("**li**", "li"), ("#::luisa", "luisa"), - (None, None) + (None, None), ] expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) assert_df_equality(actual_df, expected_df) - def test_remove_non_word_characters_long_error(): - source_data = [ - ("matt7",), - ("bill&",), - ("isabela*",), - (None,) - ] + source_data = [("matt7",), ("bill&",), ("isabela*",), (None,)] source_df = spark.createDataFrame(source_data, ["name"]) - actual_df = source_df.withColumn( - "clean_name", - remove_non_word_characters(F.col("name")) - ) + actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) expected_data = [ ("matt7", "matt"), ("bill&", "bill"), ("isabela*", "isabela"), - (None, None) + (None, None), ] expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) # assert_df_equality(actual_df, expected_df) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_df_equality(actual_df, expected_df) - def ignore_row_order(): df1 = spark.createDataFrame([(1,), (2,), (3,)], ["some_num"]) df2 = spark.createDataFrame([(2,), (1,), (3,)], ["some_num"]) # assert_df_equality(df1, df2) assert_df_equality(df1, df2, ignore_row_order=True) - def ignore_column_order(): df1 = spark.createDataFrame([(1, 7), (2, 8), (3, 9)], ["num1", "num2"]) df2 = spark.createDataFrame([(7, 1), (8, 2), (9, 3)], ["num2", "num1"]) assert_df_equality(df1, df2, ignore_column_order=True) - def ignore_nullable_property(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) df1 = spark.createDataFrame([("juan", 7), ("bruna", 8)], s1) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), False)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), False), + ]) df2 = spark.createDataFrame([("juan", 7), ("bruna", 8)], s2) assert_df_equality(df1, df2, ignore_nullable=True) - def ignore_nullable_property_array(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("coords", ArrayType(DoubleType(), True), True),]) + StructField("name", StringType(), True), + StructField("coords", ArrayType(DoubleType(), True), True), + ]) df1 = spark.createDataFrame([("juan", [1.42, 3.5]), ("bruna", [2.76, 3.2])], s1) s2 = StructType([ - StructField("name", StringType(), True), - StructField("coords", ArrayType(DoubleType(), False), True),]) + StructField("name", StringType(), True), + StructField("coords", ArrayType(DoubleType(), False), True), + ]) df2 = 
spark.createDataFrame([("juan", [1.42, 3.5]), ("bruna", [2.76, 3.2])], s2) assert_df_equality(df1, df2, ignore_nullable=True) - def consider_nan_values_equal(): - data1 = [(float('nan'), "jose"), (2.0, "li")] + data1 = [(float("nan"), "jose"), (2.0, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) - data2 = [(float('nan'), "jose"), (2.0, "li")] + data2 = [(float("nan"), "jose"), (2.0, "li")] df2 = spark.createDataFrame(data2, ["num", "name"]) assert_df_equality(df1, df2, allow_nan_equality=True) @@ -152,8 +140,8 @@ def it_prints_underline_message(): ("li", 99), ("rick", 66), ] - df2 = spark.createDataFrame(data, ["firstname", "age"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + df2 = spark.createDataFrame(data, ["firstname", "age"]) + with pytest.raises(DataFramesNotEqualError): assert_df_equality(df1, df2, underline_cells=True) def it_shows_assert_basic_rows_equality(my_formats): @@ -173,110 +161,60 @@ def it_shows_assert_basic_rows_equality(my_formats): ] df2 = spark.createDataFrame(data, ["firstname", "age"]) # assert_basic_rows_equality(df1.collect(), df2.collect(), formats=my_formats) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_basic_rows_equality(df1.collect(), df2.collect(), underline_cells=True) + def describe_assert_approx_column_equality(): def test_approx_col_equality_same(): - data = [ - (1.1, 1.1), - (2.2, 2.15), - (3.3, 3.37), - (None, None) - ] + data = [(1.1, 1.1), (2.2, 2.15), (3.3, 3.37), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) assert_approx_column_equality(df, "num1", "num2", 0.1) - def test_approx_col_equality_different(): - data = [ - (1.1, 1.1), - (2.2, 2.15), - (3.3, 5.0), - (None, None) - ] + data = [(1.1, 1.1), (2.2, 2.15), (3.3, 5.0), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) - with pytest.raises(ColumnsNotEqualError) as e_info: + with pytest.raises(ColumnsNotEqualError): assert_approx_column_equality(df, "num1", "num2", 0.1) - def test_approx_df_equality_same(): - data1 = [ - (1.1, "a"), - (2.2, "b"), - (3.3, "c"), - (None, None) - ] + data1 = [(1.1, "a"), (2.2, "b"), (3.3, "c"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "letter"]) - data2 = [ - (1.05, "a"), - (2.13, "b"), - (3.3, "c"), - (None, None) - ] + data2 = [(1.05, "a"), (2.13, "b"), (3.3, "c"), (None, None)] df2 = spark.createDataFrame(data2, ["num", "letter"]) assert_approx_df_equality(df1, df2, 0.1) - def test_approx_df_equality_different(): - data1 = [ - (1.1, "a"), - (2.2, "b"), - (3.3, "c"), - (None, None) - ] + data1 = [(1.1, "a"), (2.2, "b"), (3.3, "c"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "letter"]) - data2 = [ - (1.1, "a"), - (5.0, "b"), - (3.3, "z"), - (None, None) - ] + data2 = [(1.1, "a"), (5.0, "b"), (3.3, "z"), (None, None)] df2 = spark.createDataFrame(data2, ["num", "letter"]) # assert_approx_df_equality(df1, df2, 0.1) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_approx_df_equality(df1, df2, 0.1) def describe_schema_mismatch_messages(): def test_schema_mismatch_message(): - data1 = [ - (1, "a"), - (2, "b"), - (3, "c"), - (None, None) - ] + data1 = [(1, "a"), (2, "b"), (3, "c"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "letter"]) - data2 = [ - (1, 6), - (2, 7), - (3, 8), - (None, None) - ] + data2 = [(1, 6), (2, 7), (3, 8), (None, None)] df2 = spark.createDataFrame(data2, ["num", "num2"]) - with pytest.raises(SchemasNotEqualError) as 
e_info: + with pytest.raises(SchemasNotEqualError): assert_df_equality(df1, df2) def test_remove_non_word_characters_long_error(my_chispa): - source_data = [ - ("matt7",), - ("bill&",), - ("isabela*",), - (None,) - ] + source_data = [("matt7",), ("bill&",), ("isabela*",), (None,)] source_df = spark.createDataFrame(source_data, ["name"]) - actual_df = source_df.withColumn( - "clean_name", - remove_non_word_characters(F.col("name")) - ) + actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) expected_data = [ ("matt7", "matt"), ("bill&", "bill"), ("isabela*", "isabela"), - (None, None) + (None, None), ] expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) # my_chispa.assert_df_equality(actual_df, expected_df) - with pytest.raises(DataFramesNotEqualError) as e_info: - assert_df_equality(actual_df, expected_df) \ No newline at end of file + with pytest.raises(DataFramesNotEqualError): + assert_df_equality(actual_df, expected_df) diff --git a/tests/test_row_comparer.py b/tests/test_row_comparer.py index 2f6c899..fe9b48c 100644 --- a/tests/test_row_comparer.py +++ b/tests/test_row_comparer.py @@ -1,28 +1,28 @@ -import pytest +from __future__ import annotations -from .spark import * -from chispa.row_comparer import * from pyspark.sql import Row +from chispa.row_comparer import are_rows_approx_equal, are_rows_equal, are_rows_equal_enhanced + def test_are_rows_equal(): - assert are_rows_equal(Row("bob", "jose"), Row("li", "li")) == False - assert are_rows_equal(Row("luisa", "laura"), Row("luisa", "laura")) == True - assert are_rows_equal(Row(None, None), Row(None, None)) == True + assert are_rows_equal(Row("bob", "jose"), Row("li", "li")) is False + assert are_rows_equal(Row("luisa", "laura"), Row("luisa", "laura")) is True + assert are_rows_equal(Row(None, None), Row(None, None)) is True + def test_are_rows_equal_enhanced(): - assert are_rows_equal_enhanced(Row(n1 = "bob", n2 = "jose"), Row(n1 = "li", n2 = "li"), False) == False - assert are_rows_equal_enhanced(Row(n1 = "luisa", n2 = "laura"), Row(n1 = "luisa", n2 = "laura"), False) == True - assert are_rows_equal_enhanced(Row(n1 = None, n2 = None), Row(n1 = None, n2 = None), False) == True + assert are_rows_equal_enhanced(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), False) is False + assert are_rows_equal_enhanced(Row(n1="luisa", n2="laura"), Row(n1="luisa", n2="laura"), False) is True + assert are_rows_equal_enhanced(Row(n1=None, n2=None), Row(n1=None, n2=None), False) is True - assert are_rows_equal_enhanced(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), True) == False - assert are_rows_equal_enhanced(Row(n1=float('nan'), n2="jose"), Row(n1=float('nan'), n2="jose"), True) == True - assert are_rows_equal_enhanced(Row(n1=float('nan'), n2="jose"), Row(n1="hi", n2="jose"), True) == False + assert are_rows_equal_enhanced(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), True) is False + assert are_rows_equal_enhanced(Row(n1=float("nan"), n2="jose"), Row(n1=float("nan"), n2="jose"), True) is True + assert are_rows_equal_enhanced(Row(n1=float("nan"), n2="jose"), Row(n1="hi", n2="jose"), True) is False def test_are_rows_approx_equal(): - assert are_rows_approx_equal(Row(num = 1.1, first_name = "li"), Row(num = 1.05, first_name = "li"), 0.1) == True - assert are_rows_approx_equal(Row(num = 5.0, first_name = "laura"), Row(num = 5.0, first_name = "laura"), 0.1) == True - assert are_rows_approx_equal(Row(num = 5.0, first_name = "laura"), Row(num = 5.9, first_name = "laura"), 0.1) == False - 
assert are_rows_approx_equal(Row(num = None, first_name = None), Row(num = None, first_name = None), 0.1) == True - + assert are_rows_approx_equal(Row(num=1.1, first_name="li"), Row(num=1.05, first_name="li"), 0.1) is True + assert are_rows_approx_equal(Row(num=5.0, first_name="laura"), Row(num=5.0, first_name="laura"), 0.1) is True + assert are_rows_approx_equal(Row(num=5.0, first_name="laura"), Row(num=5.9, first_name="laura"), 0.1) is False + assert are_rows_approx_equal(Row(num=None, first_name=None), Row(num=None, first_name=None), 0.1) is True diff --git a/tests/test_rows_comparer.py b/tests/test_rows_comparer.py index 73ee4f7..34b06fc 100644 --- a/tests/test_rows_comparer.py +++ b/tests/test_rows_comparer.py @@ -1,10 +1,10 @@ +from __future__ import annotations + import pytest -from .spark import * -from chispa import * -from chispa.rows_comparer import assert_basic_rows_equality -from chispa import DataFramesNotEqualError -import math +from chispa import DataFramesNotEqualError, assert_basic_rows_equality + +from .spark import spark def describe_assert_basic_rows_equality(): @@ -13,7 +13,7 @@ def it_throws_with_row_mismatches(): df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_basic_rows_equality(df1.collect(), df2.collect()) def it_throws_when_rows_have_different_lengths(): @@ -21,7 +21,7 @@ def it_throws_when_rows_have_different_lengths(): df1 = spark.createDataFrame(data1, ["num", "expected_name"]) data2 = [(1, "jose"), (2, "li"), (3, "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) - with pytest.raises(DataFramesNotEqualError) as e_info: + with pytest.raises(DataFramesNotEqualError): assert_basic_rows_equality(df1.collect(), df2.collect()) def it_works_when_rows_are_the_same(): @@ -30,4 +30,3 @@ def it_works_when_rows_are_the_same(): data2 = [(1, "jose"), (2, "li"), (3, "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) assert_basic_rows_equality(df1.collect(), df2.collect()) - diff --git a/tests/test_schema_comparer.py b/tests/test_schema_comparer.py index f679fb0..eee7d9b 100644 --- a/tests/test_schema_comparer.py +++ b/tests/test_schema_comparer.py @@ -1,164 +1,176 @@ +from __future__ import annotations + import pytest +from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType, StructField, StructType -from pyspark.sql.types import * -from chispa.schema_comparer import * +from chispa.schema_comparer import ( + SchemasNotEqualError, + are_schemas_equal_ignore_nullable, + are_structfields_equal, + assert_schema_equality, + assert_schema_equality_ignore_nullable, +) def describe_assert_schema_equality(): def it_does_nothing_when_equal(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) assert_schema_equality(s1, s2) - def it_throws_when_column_names_differ(): s1 = StructType([ - StructField("HAHA", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("HAHA", StringType(), True), + StructField("age", IntegerType(), 
True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) - with pytest.raises(SchemasNotEqualError) as e_info: + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) + with pytest.raises(SchemasNotEqualError): assert_schema_equality(s1, s2) - def it_throws_when_schema_lengths_differ(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("fav_number", IntegerType(), True)]) - with pytest.raises(SchemasNotEqualError) as e_info: + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("fav_number", IntegerType(), True), + ]) + with pytest.raises(SchemasNotEqualError): assert_schema_equality(s1, s2) def describe_assert_schema_equality_ignore_nullable(): def it_has_good_error_messages_for_different_sized_schemas(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), False), - StructField("age", IntegerType(), True), - StructField("something", IntegerType(), True), - StructField("else", IntegerType(), True) + StructField("name", StringType(), False), + StructField("age", IntegerType(), True), + StructField("something", IntegerType(), True), + StructField("else", IntegerType(), True), ]) - with pytest.raises(SchemasNotEqualError) as e_info: + with pytest.raises(SchemasNotEqualError): assert_schema_equality_ignore_nullable(s1, s2) - def it_does_nothing_when_equal(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) assert_schema_equality_ignore_nullable(s1, s2) - def it_does_nothing_when_only_nullable_flag_is_different(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), False)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), False), + ]) assert_schema_equality_ignore_nullable(s1, s2) def describe_are_schemas_equal_ignore_nullable(): def it_returns_true_when_only_nullable_flag_is_different(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("coords", ArrayType(DoubleType(), True), True), + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("coords", ArrayType(DoubleType(), True), True), ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), False), - StructField("coords", ArrayType(DoubleType(), True), False), + StructField("name", StringType(), True), + StructField("age", IntegerType(), False), + StructField("coords", ArrayType(DoubleType(), 
True), False), ]) - assert are_schemas_equal_ignore_nullable(s1, s2) == True - + assert are_schemas_equal_ignore_nullable(s1, s2) is True def it_returns_true_when_only_nullable_flag_is_different_within_array_element(): s1 = StructType([StructField("coords", ArrayType(DoubleType(), True), True)]) s2 = StructType([StructField("coords", ArrayType(DoubleType(), False), True)]) - assert are_schemas_equal_ignore_nullable(s1, s2) == True + assert are_schemas_equal_ignore_nullable(s1, s2) is True def it_returns_true_when_only_nullable_flag_is_different_within_nested_array_element(): s1 = StructType([StructField("coords", ArrayType(ArrayType(DoubleType(), True), True), True)]) s2 = StructType([StructField("coords", ArrayType(ArrayType(DoubleType(), False), True), True)]) - assert are_schemas_equal_ignore_nullable(s1, s2) == True - + assert are_schemas_equal_ignore_nullable(s1, s2) is True def it_returns_false_when_the_element_type_is_different_within_array(): s1 = StructType([StructField("coords", ArrayType(DoubleType(), True), True)]) s2 = StructType([StructField("coords", ArrayType(IntegerType(), True), True)]) - assert are_schemas_equal_ignore_nullable(s1, s2) == False - + assert are_schemas_equal_ignore_nullable(s1, s2) is False def it_returns_false_when_column_names_differ(): s1 = StructType([ - StructField("blah", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("blah", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), False)]) - assert are_schemas_equal_ignore_nullable(s1, s2) == False + StructField("name", StringType(), True), + StructField("age", IntegerType(), False), + ]) + assert are_schemas_equal_ignore_nullable(s1, s2) is False def it_returns_false_when_columns_have_different_order(): s1 = StructType([ - StructField("blah", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("blah", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("age", IntegerType(), False), - StructField("blah", StringType(), True)]) - assert are_schemas_equal_ignore_nullable(s1, s2) == False + StructField("age", IntegerType(), False), + StructField("blah", StringType(), True), + ]) + assert are_schemas_equal_ignore_nullable(s1, s2) is False def describe_are_structfields_equal(): def it_returns_true_when_only_nullable_flag_is_different_within_array_element(): s1 = StructField("coords", ArrayType(DoubleType(), True), True) s2 = StructField("coords", ArrayType(DoubleType(), False), True) - assert are_structfields_equal(s1, s2, True) == True - + assert are_structfields_equal(s1, s2, True) is True def it_returns_false_when_the_element_type_is_different_within_array(): s1 = StructField("coords", ArrayType(DoubleType(), True), True) s2 = StructField("coords", ArrayType(IntegerType(), True), True) - assert are_structfields_equal(s1, s2, True) == False - + assert are_structfields_equal(s1, s2, True) is False def it_returns_true_when_the_element_type_is_same_within_struct(): s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) s2 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) - assert are_structfields_equal(s1, s2, True) == True - + assert are_structfields_equal(s1, s2, True) is True def it_returns_false_when_the_element_type_is_different_within_struct(): s1 = StructField("coords", StructType([StructField("hello", DoubleType(), 
True)]), True) s2 = StructField("coords", StructType([StructField("hello", IntegerType(), True)]), True) - assert are_structfields_equal(s1, s2, True) == False - + assert are_structfields_equal(s1, s2, True) is False def it_returns_false_when_the_element_name_is_different_within_struct(): s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) s2 = StructField("coords", StructType([StructField("world", DoubleType(), True)]), True) - assert are_structfields_equal(s1, s2, True) == False - - + assert are_structfields_equal(s1, s2, True) is False + def it_returns_true_when_different_nullability_within_struct(): s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) s2 = StructField("coords", StructType([StructField("hello", DoubleType(), False)]), True) - assert are_structfields_equal(s1, s2, True) == True + assert are_structfields_equal(s1, s2, True) is True + def it_returns_false_when_metadata_differs(): s1 = StructField("coords", StringType(), True, {"hi": "whatever"}) s2 = StructField("coords", StringType(), True, {"hi": "no"}) diff --git a/tests/test_structfield_comparer.py b/tests/test_structfield_comparer.py index df40aea..a6d181d 100644 --- a/tests/test_structfield_comparer.py +++ b/tests/test_structfield_comparer.py @@ -1,56 +1,57 @@ -import pytest +from __future__ import annotations + +from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType from chispa.structfield_comparer import are_structfields_equal -from pyspark.sql.types import * def describe_are_structfields_equal(): def it_returns_true_when_structfields_are_the_same(): sf1 = StructField("hi", IntegerType(), True) sf2 = StructField("hi", IntegerType(), True) - assert are_structfields_equal(sf1, sf2) == True + assert are_structfields_equal(sf1, sf2) is True def it_returns_false_when_column_names_are_different(): sf1 = StructField("hello", IntegerType(), True) sf2 = StructField("hi", IntegerType(), True) - assert are_structfields_equal(sf1, sf2) == False + assert are_structfields_equal(sf1, sf2) is False def it_returns_false_when_nullable_property_is_different(): sf1 = StructField("hi", IntegerType(), False) sf2 = StructField("hi", IntegerType(), True) - assert are_structfields_equal(sf1, sf2) == False + assert are_structfields_equal(sf1, sf2) is False def it_can_perform_nullability_insensitive_comparisons(): sf1 = StructField("hi", IntegerType(), False) sf2 = StructField("hi", IntegerType(), True) - assert are_structfields_equal(sf1, sf2, ignore_nullability=True) == True + assert are_structfields_equal(sf1, sf2, ignore_nullability=True) is True def it_returns_true_when_nested_types_are_the_same(): sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) sf2 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) - assert are_structfields_equal(sf1, sf2) == True + assert are_structfields_equal(sf1, sf2) is True def it_returns_false_when_nested_names_are_different(): sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) sf2 = StructField("hi", StructType([StructField("developer", IntegerType(), False)]), False) - assert are_structfields_equal(sf1, sf2) == False + assert are_structfields_equal(sf1, sf2) is False def it_returns_false_when_nested_types_are_different(): sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) sf2 = StructField("hi", StructType([StructField("world", DoubleType(), False)]), False) - assert 
are_structfields_equal(sf1, sf2) == False + assert are_structfields_equal(sf1, sf2) is False def it_returns_false_when_nested_types_have_different_nullability(): sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) sf2 = StructField("hi", StructType([StructField("world", IntegerType(), True)]), False) - assert are_structfields_equal(sf1, sf2) == False + assert are_structfields_equal(sf1, sf2) is False def it_returns_false_when_nested_types_are_different_with_ignore_nullable_true(): sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) sf2 = StructField("hi", StructType([StructField("developer", IntegerType(), False)]), False) - assert are_structfields_equal(sf1, sf2, ignore_nullability=True) == False + assert are_structfields_equal(sf1, sf2, ignore_nullability=True) is False def it_returns_true_when_nested_types_have_different_nullability_with_ignore_null_true(): sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) sf2 = StructField("hi", StructType([StructField("world", IntegerType(), True)]), False) - assert are_structfields_equal(sf1, sf2, ignore_nullability=True) == True \ No newline at end of file + assert are_structfields_equal(sf1, sf2, ignore_nullability=True) is True diff --git a/tests/test_terminal_str_formatter.py b/tests/test_terminal_str_formatter.py deleted file mode 100644 index 7f73d4d..0000000 --- a/tests/test_terminal_str_formatter.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - -from chispa.terminal_str_formatter import format_string - - -def test_it_can_make_a_blue_string(): - print(format_string("hi", ["bold", "blink"])) - - -def test_it_works_with_no_formats(): - print(format_string("hi", []))
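
Taken together, the tests above pin down the new chispa.formatting surface. Below is a minimal usage sketch, assuming only the names those tests import (FormattingConfig, Format, Color, Style and format_string from chispa.formatting; assert_basic_rows_equality and its formats= keyword from chispa) plus a local PySpark session; the sample DataFrames are illustrative only.

from pyspark.sql import SparkSession

from chispa import assert_basic_rows_equality
from chispa.formatting import Color, Format, FormattingConfig, Style, format_string

spark = SparkSession.builder.master("local").appName("chispa").getOrCreate()

# Colors and styles are given as plain strings and validated against the
# Color and Style enums, as exercised by tests/formatting/test_formats.py.
config = FormattingConfig(
    mismatched_rows={"color": "light_yellow"},
    matched_rows={"color": "cyan", "style": "bold"},
    mismatched_cells={"color": "purple"},
    matched_cells={"color": "blue"},
)

# A Format can also be built straight from a list of color/style names.
fmt = Format.from_list(["blue", "bold"])
assert fmt.color == Color.BLUE and fmt.style == [Style.BOLD]
print(format_string("matched", fmt))  # "matched" rendered in bold blue

# The config is handed to an assertion through the formats= keyword,
# mirroring the (commented) call in tests/test_readme_examples.py.
df1 = spark.createDataFrame([(1, "jose"), (2, "li")], ["num", "name"])
df2 = spark.createDataFrame([(1, "jose"), (2, "li")], ["num", "name"])
assert_basic_rows_equality(df1.collect(), df2.collect(), formats=config)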