diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 66d2667..7b4c171 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,6 +19,9 @@ jobs: - name: Check lock file run: poetry lock --check + - name: Run pre-commit hooks + run: poetry run pre-commit run -a + test: runs-on: ubuntu-latest strategy: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..91a6304 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.5.2" + hooks: + - id: ruff + args: [--exit-non-zero-on-fix] + - id: ruff-format diff --git a/chispa/__init__.py b/chispa/__init__.py index 0c59b1b..231faae 100644 --- a/chispa/__init__.py +++ b/chispa/__init__.py @@ -9,15 +9,15 @@ # We need to add PySpark, try use findspark, or failback to the "manually" find it try: import findspark + findspark.init() except ImportError: try: - spark_home = os.environ['SPARK_HOME'] - sys.path.append(os.path.join(spark_home, 'python')) - py4j_src_zip = glob(os.path.join(spark_home, 'python', 'lib', 'py4j-*-src.zip')) + spark_home = os.environ["SPARK_HOME"] + sys.path.append(os.path.join(spark_home, "python")) + py4j_src_zip = glob(os.path.join(spark_home, "python", "lib", "py4j-*-src.zip")) if len(py4j_src_zip) == 0: - raise ValueError('py4j source archive not found in %s' - % os.path.join(spark_home, 'python', 'lib')) + raise ValueError("py4j source archive not found in %s" % os.path.join(spark_home, "python", "lib")) else: py4j_src_zip = sorted(py4j_src_zip)[::-1] sys.path.append(py4j_src_zip[0]) @@ -25,26 +25,46 @@ print("Can't find Apache Spark. 
Please set environment variable SPARK_HOME to root of installation!") exit(-1) -from .dataframe_comparer import DataFramesNotEqualError, assert_df_equality, assert_approx_df_equality -from .column_comparer import ColumnsNotEqualError, assert_column_equality, assert_approx_column_equality +from .dataframe_comparer import ( + DataFramesNotEqualError, + assert_df_equality, + assert_approx_df_equality, +) +from .column_comparer import ( + ColumnsNotEqualError, + assert_column_equality, + assert_approx_column_equality, +) from .rows_comparer import assert_basic_rows_equality from chispa.default_formats import DefaultFormats -class Chispa(): + +class Chispa: def __init__(self, formats=DefaultFormats(), default_output=None): self.formats = formats self.default_outputs = default_output - def assert_df_equality(self, df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, - ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False): + def assert_df_equality( + self, + df1, + df2, + ignore_nullable=False, + transforms=None, + allow_nan_equality=False, + ignore_column_order=False, + ignore_row_order=False, + underline_cells=False, + ignore_metadata=False, + ): return assert_df_equality( - df1, - df2, - ignore_nullable, - transforms, + df1, + df2, + ignore_nullable, + transforms, allow_nan_equality, - ignore_column_order, - ignore_row_order, - underline_cells, + ignore_column_order, + ignore_row_order, + underline_cells, ignore_metadata, - self.formats) \ No newline at end of file + self.formats, + ) diff --git a/chispa/bcolors.py b/chispa/bcolors.py index bbbb930..3c531f2 100644 --- a/chispa/bcolors.py +++ b/chispa/bcolors.py @@ -1,32 +1,32 @@ class bcolors: - NC = '\033[0m' # No Color, reset all + NC = "\033[0m" # No Color, reset all - Bold = '\033[1m' - Underlined = '\033[4m' - Blink = '\033[5m' - Inverted = '\033[7m' - Hidden = '\033[8m' + Bold = "\033[1m" + Underlined = "\033[4m" + Blink = "\033[5m" + Inverted = 
"\033[7m" + Hidden = "\033[8m" - Black = '\033[30m' - Red = '\033[31m' - Green = '\033[32m' - Yellow = '\033[33m' - Blue = '\033[34m' - Purple = '\033[35m' - Cyan = '\033[36m' - LightGray = '\033[37m' - DarkGray = '\033[30m' - LightRed = '\033[31m' - LightGreen = '\033[32m' - LightYellow = '\033[93m' - LightBlue = '\033[34m' - LightPurple = '\033[35m' - LightCyan = '\033[36m' - White = '\033[97m' + Black = "\033[30m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Purple = "\033[35m" + Cyan = "\033[36m" + LightGray = "\033[37m" + DarkGray = "\033[30m" + LightRed = "\033[31m" + LightGreen = "\033[32m" + LightYellow = "\033[93m" + LightBlue = "\033[34m" + LightPurple = "\033[35m" + LightCyan = "\033[36m" + White = "\033[97m" # Style - Bold = '\033[1m' - Underline = '\033[4m' + Bold = "\033[1m" + Underline = "\033[4m" def blue(s: str) -> str: diff --git a/chispa/column_comparer.py b/chispa/column_comparer.py index b201b2d..363d050 100644 --- a/chispa/column_comparer.py +++ b/chispa/column_comparer.py @@ -3,8 +3,9 @@ class ColumnsNotEqualError(Exception): - """The columns are not equal""" - pass + """The columns are not equal""" + + pass def assert_column_equality(df, col_name1, col_name2): @@ -35,11 +36,11 @@ def assert_approx_column_equality(df, col_name1, col_name2, precision): first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed # when one is None and the other isn't, they're not equal - if (elements[0] == None and elements[1] != None) or (elements[0] != None and elements[1] == None): + if (elements[0] is None and elements[1] is not None) or (elements[0] is not None and elements[1] is None): all_rows_equal = False t.add_row([str(elements[0]), str(elements[1])]) # when both are None, they're equal - elif elements[0] == None and elements[1] == None: + elif elements[0] is None and elements[1] is None: t.add_row([first, second]) # when the diff is less than 
the threshhold, they're approximately equal elif abs(elements[0] - elements[1]) < precision: @@ -50,4 +51,3 @@ def assert_approx_column_equality(df, col_name1, col_name2, precision): t.add_row([str(elements[0]), str(elements[1])]) if all_rows_equal == False: raise ColumnsNotEqualError("\n" + t.get_string()) - diff --git a/chispa/dataframe_comparer.py b/chispa/dataframe_comparer.py index 6b2fc67..d95564b 100644 --- a/chispa/dataframe_comparer.py +++ b/chispa/dataframe_comparer.py @@ -1,17 +1,31 @@ from chispa.schema_comparer import assert_schema_equality from chispa.default_formats import DefaultFormats -from chispa.rows_comparer import assert_basic_rows_equality, assert_generic_rows_equality +from chispa.rows_comparer import ( + assert_basic_rows_equality, + assert_generic_rows_equality, +) from chispa.row_comparer import are_rows_equal_enhanced, are_rows_approx_equal from functools import reduce class DataFramesNotEqualError(Exception): - """The DataFrames are not equal""" - pass + """The DataFrames are not equal""" + pass -def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, - ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False, formats=DefaultFormats()): + +def assert_df_equality( + df1, + df2, + ignore_nullable=False, + transforms=None, + allow_nan_equality=False, + ignore_column_order=False, + ignore_row_order=False, + underline_cells=False, + ignore_metadata=False, + formats=DefaultFormats(), +): if transforms is None: transforms = [] if ignore_column_order: @@ -23,10 +37,20 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n assert_schema_equality(df1.schema, df2.schema, ignore_nullable, ignore_metadata) if allow_nan_equality: assert_generic_rows_equality( - df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], underline_cells=underline_cells, formats=formats) + df1.collect(), + df2.collect(), + are_rows_equal_enhanced, + [True], 
+ underline_cells=underline_cells, + formats=formats, + ) else: assert_basic_rows_equality( - df1.collect(), df2.collect(), underline_cells=underline_cells, formats=formats) + df1.collect(), + df2.collect(), + underline_cells=underline_cells, + formats=formats, + ) def are_dfs_equal(df1, df2): @@ -37,8 +61,17 @@ def are_dfs_equal(df1, df2): return True -def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transforms=None, allow_nan_equality=False, - ignore_column_order=False, ignore_row_order=False, formats=DefaultFormats()): +def assert_approx_df_equality( + df1, + df2, + precision, + ignore_nullable=False, + transforms=None, + allow_nan_equality=False, + ignore_column_order=False, + ignore_row_order=False, + formats=DefaultFormats(), +): if transforms is None: transforms = [] if ignore_column_order: @@ -49,7 +82,13 @@ def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transf df2 = reduce(lambda acc, fn: fn(acc), transforms, df2) assert_schema_equality(df1.schema, df2.schema, ignore_nullable) if precision != 0: - assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality], formats) + assert_generic_rows_equality( + df1.collect(), + df2.collect(), + are_rows_approx_equal, + [precision, allow_nan_equality], + formats, + ) elif allow_nan_equality: assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], formats) else: diff --git a/chispa/default_formats.py b/chispa/default_formats.py index 60654d7..433db11 100644 --- a/chispa/default_formats.py +++ b/chispa/default_formats.py @@ -1,5 +1,6 @@ from dataclasses import dataclass + @dataclass class DefaultFormats: mismatched_rows = ["red"] diff --git a/chispa/number_helpers.py b/chispa/number_helpers.py index 9f8713c..818d3f5 100644 --- a/chispa/number_helpers.py +++ b/chispa/number_helpers.py @@ -13,4 +13,4 @@ def nan_safe_equality(x, y) -> bool: def nan_safe_approx_equality(x, y, 
precision) -> bool: - return (abs(x-y)<=precision) or (isnan(x) and isnan(y)) \ No newline at end of file + return (abs(x - y) <= precision) or (isnan(x) and isnan(y)) diff --git a/chispa/row_comparer.py b/chispa/row_comparer.py index 886d2be..781cb1b 100644 --- a/chispa/row_comparer.py +++ b/chispa/row_comparer.py @@ -16,7 +16,7 @@ def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool: d2 = r2.asDict() if allow_nan_equality: for key in d1.keys() & d2.keys(): - if not(nan_safe_equality(d1[key], d2[key])): + if not (nan_safe_equality(d1[key], d2[key])): return False return True else: @@ -33,13 +33,12 @@ def are_rows_approx_equal(r1: Row, r2: Row, precision: float, allow_nan_equality allEqual = True for key in d1.keys() & d2.keys(): if isinstance(d1[key], float) and isinstance(d2[key], float): - if allow_nan_equality and not(nan_safe_approx_equality(d1[key], d2[key], precision)): + if allow_nan_equality and not (nan_safe_approx_equality(d1[key], d2[key], precision)): allEqual = False - elif not(allow_nan_equality) and math.isnan(abs(d1[key] - d2[key])): + elif not (allow_nan_equality) and math.isnan(abs(d1[key] - d2[key])): allEqual = False elif abs(d1[key] - d2[key]) > precision: allEqual = False elif d1[key] != d2[key]: allEqual = False return allEqual - diff --git a/chispa/rows_comparer.py b/chispa/rows_comparer.py index e52cc45..3031a54 100644 --- a/chispa/rows_comparer.py +++ b/chispa/rows_comparer.py @@ -2,8 +2,6 @@ from prettytable import PrettyTable from chispa.bcolors import * import chispa -from pyspark.sql.types import Row -from typing import List from chispa.terminal_str_formatter import format_string from chispa.default_formats import DefaultFormats @@ -41,7 +39,14 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=Defa raise chispa.DataFramesNotEqualError("\n" + t.get_string()) -def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False, 
formats=DefaultFormats()): +def assert_generic_rows_equality( + rows1, + rows2, + row_equality_fun, + row_equality_fun_args, + underline_cells=False, + formats=DefaultFormats(), +): df1_rows = rows1 df2_rows = rows2 zipped = list(zip_longest(df1_rows, df2_rows)) @@ -51,12 +56,18 @@ def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fu # rows are not equal when one is None and the other isn't if (r1 is not None and r2 is None) or (r2 is not None and r1 is None): all_rows_equal = False - t.add_row([format_string(r1, formats.mismatched_rows), format_string(r2, formats.mismatched_rows)]) + t.add_row([ + format_string(r1, formats.mismatched_rows), + format_string(r2, formats.mismatched_rows), + ]) # rows are equal elif row_equality_fun(r1, r2, *row_equality_fun_args): r1_string = ", ".join(map(lambda f: f"{f}={r1[f]}", r1.__fields__)) r2_string = ", ".join(map(lambda f: f"{f}={r2[f]}", r2.__fields__)) - t.add_row([format_string(r1_string, formats.matched_rows), format_string(r2_string, formats.matched_rows)]) + t.add_row([ + format_string(r1_string, formats.matched_rows), + format_string(r2_string, formats.matched_rows), + ]) # otherwise, rows aren't equal else: r_zipped = list(zip_longest(r1.__fields__, r2.__fields__)) diff --git a/chispa/schema_comparer.py b/chispa/schema_comparer.py index 4bb77e2..8f50c15 100644 --- a/chispa/schema_comparer.py +++ b/chispa/schema_comparer.py @@ -4,8 +4,9 @@ class SchemasNotEqualError(Exception): - """The schemas are not equal""" - pass + """The schemas are not equal""" + + pass def assert_schema_equality(s1, s2, ignore_nullable=False, ignore_metadata=False): @@ -51,7 +52,6 @@ def assert_basic_schema_equality(s1, s2): raise SchemasNotEqualError("\n" + t.get_string()) - # deprecate this. ignore_nullable should be a flag. 
def assert_schema_equality_ignore_nullable(s1, s2): if not are_schemas_equal_ignore_nullable(s1, s2): @@ -101,9 +101,9 @@ def are_datatypes_equal_ignore_nullable(dt1, dt2): """ if dt1.typeName() == dt2.typeName(): # Account for array types by inspecting elementType. - if dt1.typeName() == 'array': + if dt1.typeName() == "array": return are_datatypes_equal_ignore_nullable(dt1.elementType, dt2.elementType) - elif dt1.typeName() == 'struct': + elif dt1.typeName() == "struct": return are_schemas_equal_ignore_nullable(dt1, dt2) else: return True diff --git a/chispa/structfield_comparer.py b/chispa/structfield_comparer.py index dc448a1..925dbbc 100644 --- a/chispa/structfield_comparer.py +++ b/chispa/structfield_comparer.py @@ -1 +1,3 @@ -from chispa.schema_comparer import are_structfields_equal\ No newline at end of file +from chispa.schema_comparer import are_structfields_equal + +__all__ = ["are_structfields_equal"] diff --git a/chispa/terminal_str_formatter.py b/chispa/terminal_str_formatter.py index ef5ace6..69174d7 100644 --- a/chispa/terminal_str_formatter.py +++ b/chispa/terminal_str_formatter.py @@ -1,30 +1,30 @@ def format_string(input, formats): formatting = { - "nc": '\033[0m', # No Color, reset all - "bold": '\033[1m', - "underline": '\033[4m', - "blink": '\033[5m', - "blue": '\033[34m', - "white": '\033[97m', - "red": '\033[31m', - "invert": '\033[7m', - "hide": '\033[8m', - "black": '\033[30m', - "green": '\033[32m', - "yellow": '\033[33m', - "purple": '\033[35m', - "cyan": '\033[36m', - "light_gray": '\033[37m', - "dark_gray": '\033[30m', - "light_red": '\033[31m', - "light_green": '\033[32m', - "light_yellow": '\033[93m', - "light_blue": '\033[34m', - "light_purple": '\033[35m', - "light_cyan": '\033[36m', + "nc": "\033[0m", # No Color, reset all + "bold": "\033[1m", + "underline": "\033[4m", + "blink": "\033[5m", + "blue": "\033[34m", + "white": "\033[97m", + "red": "\033[31m", + "invert": "\033[7m", + "hide": "\033[8m", + "black": "\033[30m", + "green": 
"\033[32m", + "yellow": "\033[33m", + "purple": "\033[35m", + "cyan": "\033[36m", + "light_gray": "\033[37m", + "dark_gray": "\033[30m", + "light_red": "\033[31m", + "light_green": "\033[32m", + "light_yellow": "\033[93m", + "light_blue": "\033[34m", + "light_purple": "\033[35m", + "light_cyan": "\033[36m", } formatted = input for format in formats: s = formatting[format] formatted = s + str(formatted) + s - return formatting["nc"] + str(formatted) + formatting["nc"] \ No newline at end of file + return formatting["nc"] + str(formatted) + formatting["nc"] diff --git a/poetry.lock b/poetry.lock index 87cb1cb..8612623 100644 --- a/poetry.lock +++ b/poetry.lock @@ -43,6 +43,17 @@ files = [ {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, ] +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.3.2" @@ -234,6 +245,17 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "distlib" +version = "0.3.8" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, + {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -248,6 +270,22 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "filelock" 
+version = "3.15.4" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] +typing = ["typing-extensions (>=4.8)"] + [[package]] name = "findspark" version = "1.4.2" @@ -291,6 +329,20 @@ files = [ astunparse = {version = ">=1.6", markers = "python_version < \"3.9\""} colorama = ">=0.4" +[[package]] +name = "identify" +version = "2.6.0" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.6.0-py2.py3-none-any.whl", hash = "sha256:e79ae4406387a9d300332b5fd366d8994f1525e8414984e1a59e058b2eda2dd0"}, + {file = "identify-2.6.0.tar.gz", hash = "sha256:cb171c685bdc31bcc4c1734698736a7d5b6c8bf2e0c15117f4d469c8640ae5cf"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.7" @@ -653,6 +705,17 @@ files = [ griffe = ">=0.47" mkdocstrings = ">=0.25" +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = 
"packaging" version = "24.1" @@ -716,6 +779,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "3.3.3" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pre_commit-3.3.3-py2.py3-none-any.whl", hash = "sha256:10badb65d6a38caff29703362271d7dca483d01da88f9d7e05d0b97171c136cb"}, + {file = "pre_commit-3.3.3.tar.gz", hash = "sha256:a2256f489cd913d575c145132ae196fe335da32d91a8294b7afe6622335dd023"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "prettytable" version = "3.10.2" @@ -1108,6 +1189,26 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.26.3" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, + {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "watchdog" version = "4.0.1" @@ -1195,4 +1296,4 @@ 
test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "e5687bd8f2fbe096e8eba48515310451662d0fd23db6cbb84df14e1d5f90f978" +content-hash = "9fde9a932fca40538936262263439debec9162feea75797fc973fe6a92a770b1" diff --git a/pyproject.toml b/pyproject.toml index f71f55f..7ff27ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ pyspark = ">3.0.0" findspark = "1.4.2" pytest-describe = "^2.1.0" pytest-cov = "^5.0.0" +pre-commit = "3.3.3" [tool.poetry.group.mkdocs.dependencies] mkdocs = "^1.6.0" @@ -57,3 +58,30 @@ optional = true [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.ruff] +target-version = "py39" +line-length = 120 +fix = true + +[tool.ruff.format] +preview = true + +[tool.ruff.lint] +ignore = [ + "E501", # Line too long + "E712", # Avoid equality comparisons to `False`; use `if not ...:` for false checks + "F401", # imported but unused; + "F403", # `from X import *` used; unable to detect undefined names + "F405", # X may be undefined, or defined from star imports + "F841", # Local variable `e_info` is assigned to but never used +] + +[tool.ruff.lint.flake8-type-checking] +strict = true + +[tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["S101", "S603"] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 3062077..fbf7f16 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from chispa import Chispa + @dataclass class MyFormats: mismatched_rows = ["light_yellow"] @@ -9,6 +10,7 @@ class MyFormats: mismatched_cells = ["purple"] matched_cells = ["blue"] + @pytest.fixture() def my_formats(): return MyFormats() diff --git a/tests/spark.py b/tests/spark.py index 475955f..c3bd294 100644 --- a/tests/spark.py +++ b/tests/spark.py @@ -1,7 +1,3 
@@ from pyspark.sql import SparkSession -spark = SparkSession.builder \ - .master("local") \ - .appName("chispa") \ - .getOrCreate() - +spark = SparkSession.builder.master("local").appName("chispa").getOrCreate() diff --git a/tests/test_column_comparer.py b/tests/test_column_comparer.py index dea419f..a18d75d 100644 --- a/tests/test_column_comparer.py +++ b/tests/test_column_comparer.py @@ -3,6 +3,7 @@ from .spark import spark from chispa import * + def describe_assert_column_equality(): def it_throws_error_with_data_mismatch(): data = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] @@ -15,7 +16,6 @@ def it_doesnt_throw_without_mismatch(): df = spark.createDataFrame(data, ["name", "expected_name"]) assert_column_equality(df, "name", "expected_name") - def it_works_with_integer_values(): data = [(1, 1), (10, 10), (8, 8), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) @@ -24,28 +24,24 @@ def it_works_with_integer_values(): def describe_assert_approx_column_equality(): def it_works_with_no_mismatches(): - data = [(1.1, 1.1), (1.0004, 1.0005), (.4, .45), (None, None)] + data = [(1.1, 1.1), (1.0004, 1.0005), (0.4, 0.45), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) assert_approx_column_equality(df, "num1", "num2", 0.1) - def it_throws_when_difference_is_bigger_than_precision(): - data = [(1.5, 1.1), (1.0004, 1.0005), (.4, .45)] + data = [(1.5, 1.1), (1.0004, 1.0005), (0.4, 0.45)] df = spark.createDataFrame(data, ["num1", "num2"]) with pytest.raises(ColumnsNotEqualError) as e_info: assert_approx_column_equality(df, "num1", "num2", 0.1) - def it_throws_when_comparing_floats_with_none(): data = [(1.1, 1.1), (2.2, 2.2), (3.3, None)] df = spark.createDataFrame(data, ["num1", "num2"]) with pytest.raises(ColumnsNotEqualError) as e_info: assert_approx_column_equality(df, "num1", "num2", 0.1) - def it_throws_when_comparing_none_with_floats(): data = [(1.1, 1.1), (2.2, 2.2), (None, 3.3)] df = spark.createDataFrame(data, ["num1", "num2"]) 
with pytest.raises(ColumnsNotEqualError) as e_info: assert_approx_column_equality(df, "num1", "num2", 0.1) - diff --git a/tests/test_dataframe_comparer.py b/tests/test_dataframe_comparer.py index bb90765..9a7ec7e 100644 --- a/tests/test_dataframe_comparer.py +++ b/tests/test_dataframe_comparer.py @@ -17,7 +17,6 @@ def it_throws_with_schema_mismatches(): with pytest.raises(SchemasNotEqualError) as e_info: assert_df_equality(df1, df2) - def it_can_work_with_different_row_orders(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -25,7 +24,6 @@ def it_can_work_with_different_row_orders(): df2 = spark.createDataFrame(data2, ["num", "name"]) assert_df_equality(df1, df2, transforms=[lambda df: df.sort(df.columns)]) - def it_can_work_with_different_row_orders_with_a_flag(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -33,7 +31,6 @@ def it_can_work_with_different_row_orders_with_a_flag(): df2 = spark.createDataFrame(data2, ["num", "name"]) assert_df_equality(df1, df2, ignore_row_order=True) - def it_can_work_with_different_row_and_column_orders(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -41,7 +38,6 @@ def it_can_work_with_different_row_and_column_orders(): df2 = spark.createDataFrame(data2, ["name", "num"]) assert_df_equality(df1, df2, ignore_row_order=True, ignore_column_order=True) - def it_raises_for_row_insensitive_with_diff_content(): data1 = [(1, "XXXX"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -50,7 +46,6 @@ def it_raises_for_row_insensitive_with_diff_content(): with pytest.raises(DataFramesNotEqualError): assert_df_equality(df1, df2, transforms=[lambda df: df.sort(df.columns)]) - def it_throws_with_schema_column_order_mismatch(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) @@ -59,16 +54,12 @@ def it_throws_with_schema_column_order_mismatch(): with 
pytest.raises(SchemasNotEqualError) as e_info: assert_df_equality(df1, df2) - def it_does_not_throw_on_schema_column_order_mismatch_with_transforms(): data1 = [(1, "jose"), (2, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) data2 = [("jose", 1), ("li", 2)] df2 = spark.createDataFrame(data2, ["name", "num"]) - assert_df_equality(df1, df2, transforms=[ - lambda df: df.select(sorted(df.columns)) - ]) - + assert_df_equality(df1, df2, transforms=[lambda df: df.select(sorted(df.columns))]) def it_throws_with_schema_mismatch(): data1 = [(1, "jose"), (2, "li")] @@ -78,7 +69,6 @@ def it_throws_with_schema_mismatch(): with pytest.raises(SchemasNotEqualError) as e_info: assert_df_equality(df1, df2, transforms=[lambda df: df.select(sorted(df.columns))]) - def it_throws_with_content_mismatches(): data1 = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] df1 = spark.createDataFrame(data1, ["name", "expected_name"]) @@ -87,7 +77,6 @@ def it_throws_with_content_mismatches(): with pytest.raises(DataFramesNotEqualError) as e_info: assert_df_equality(df1, df2) - def it_throws_with_length_mismatches(): data1 = [("jose", "jose"), ("li", "li"), ("laura", "laura")] df1 = spark.createDataFrame(data1, ["name", "expected_name"]) @@ -96,64 +85,51 @@ def it_throws_with_length_mismatches(): with pytest.raises(DataFramesNotEqualError) as e_info: assert_df_equality(df1, df2) - def it_can_consider_nan_values_equal(): - data1 = [(float('nan'), "jose"), (2.0, "li")] + data1 = [(float("nan"), "jose"), (2.0, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) - data2 = [(float('nan'), "jose"), (2.0, "li")] + data2 = [(float("nan"), "jose"), (2.0, "li")] df2 = spark.createDataFrame(data2, ["num", "name"]) assert_df_equality(df1, df2, allow_nan_equality=True) - def it_does_not_consider_nan_values_equal_by_default(): - data1 = [(float('nan'), "jose"), (2.0, "li")] + data1 = [(float("nan"), "jose"), (2.0, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) - data2 = 
[(float('nan'), "jose"), (2.0, "li")] + data2 = [(float("nan"), "jose"), (2.0, "li")] df2 = spark.createDataFrame(data2, ["num", "name"]) with pytest.raises(DataFramesNotEqualError) as e_info: assert_df_equality(df1, df2, allow_nan_equality=False) - def it_can_ignore_metadata(): rows_data = [("jose", 1), ("li", 2), ("luisa", 3)] - schema1 = StructType( - [ - StructField("name", StringType(), True, {"hi": "no"}), - StructField("age", IntegerType(), True), - ] - ) - schema2 = StructType( - [ - StructField("name", StringType(), True, {"hi": "whatever"}), - StructField("age", IntegerType(), True), - ] - ) + schema1 = StructType([ + StructField("name", StringType(), True, {"hi": "no"}), + StructField("age", IntegerType(), True), + ]) + schema2 = StructType([ + StructField("name", StringType(), True, {"hi": "whatever"}), + StructField("age", IntegerType(), True), + ]) df1 = spark.createDataFrame(rows_data, schema1) df2 = spark.createDataFrame(rows_data, schema2) assert_df_equality(df1, df2, ignore_metadata=True) - def it_catches_mismatched_metadata(): rows_data = [("jose", 1), ("li", 2), ("luisa", 3)] - schema1 = StructType( - [ - StructField("name", StringType(), True, {"hi": "no"}), - StructField("age", IntegerType(), True), - ] - ) - schema2 = StructType( - [ - StructField("name", StringType(), True, {"hi": "whatever"}), - StructField("age", IntegerType(), True), - ] - ) + schema1 = StructType([ + StructField("name", StringType(), True, {"hi": "no"}), + StructField("age", IntegerType(), True), + ]) + schema2 = StructType([ + StructField("name", StringType(), True, {"hi": "whatever"}), + StructField("age", IntegerType(), True), + ]) df1 = spark.createDataFrame(rows_data, schema1) df2 = spark.createDataFrame(rows_data, schema2) with pytest.raises(SchemasNotEqualError) as e_info: assert_df_equality(df1, df2) - def describe_are_dfs_equal(): def it_returns_false_with_schema_mismatches(): data1 = [(1, "jose"), (2, "li"), (3, "laura")] @@ -162,7 +138,6 @@ def 
it_returns_false_with_schema_mismatches(): df2 = spark.createDataFrame(data2, ["name", "expected_name"]) assert are_dfs_equal(df1, df2) == False - def it_returns_false_with_content_mismatches(): data1 = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] df1 = spark.createDataFrame(data1, ["name", "expected_name"]) @@ -170,7 +145,6 @@ def it_returns_false_with_content_mismatches(): df2 = spark.createDataFrame(data2, ["name", "expected_name"]) assert are_dfs_equal(df1, df2) == False - def it_returns_true_when_dfs_are_same(): data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] df1 = spark.createDataFrame(data1, ["name", "expected_name"]) @@ -188,7 +162,6 @@ def it_throws_with_content_mismatch(): with pytest.raises(DataFramesNotEqualError) as e_info: assert_approx_df_equality(df1, df2, 0.1) - def it_throws_with_with_length_mismatch(): data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) @@ -197,8 +170,6 @@ def it_throws_with_with_length_mismatch(): with pytest.raises(DataFramesNotEqualError) as e_info: assert_approx_df_equality(df1, df2, 0.1) - - def it_does_not_throw_with_no_mismatch(): data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) @@ -206,7 +177,6 @@ def it_does_not_throw_with_no_mismatch(): df2 = spark.createDataFrame(data2, ["num", "expected_name"]) assert_approx_df_equality(df1, df2, 0.1) - def it_does_not_throw_with_different_row_col_order(): data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) @@ -214,10 +184,21 @@ def it_does_not_throw_with_different_row_col_order(): df2 = spark.createDataFrame(data2, ["expected_name", "num"]) assert_approx_df_equality(df1, df2, 0.1, ignore_row_order=True, ignore_column_order=True) - def it_does_not_throw_with_nan_values(): - data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, 
None), (float("nan"), "buk")] + data1 = [ + (1.0, "jose"), + (1.1, "li"), + (1.2, "laura"), + (None, None), + (float("nan"), "buk"), + ] df1 = spark.createDataFrame(data1, ["num", "expected_name"]) - data2 = [(1.0, "jose"), (1.05, "li"), (1.2, "laura"), (None, None), (math.nan, "buk")] + data2 = [ + (1.0, "jose"), + (1.05, "li"), + (1.2, "laura"), + (None, None), + (math.nan, "buk"), + ] df2 = spark.createDataFrame(data2, ["num", "expected_name"]) assert_approx_df_equality(df1, df2, 0.1, allow_nan_equality=True) diff --git a/tests/test_readme_examples.py b/tests/test_readme_examples.py index a597ffd..d00f01d 100644 --- a/tests/test_readme_examples.py +++ b/tests/test_readme_examples.py @@ -5,17 +5,14 @@ from chispa.schema_comparer import SchemasNotEqualError from pyspark.sql.types import * +from pyspark.sql import SparkSession + def remove_non_word_characters(col): return F.regexp_replace(col, "[^\\w\\s]+", "") -from pyspark.sql import SparkSession - -spark = (SparkSession.builder - .master("local") - .appName("chispa") - .getOrCreate()) +spark = SparkSession.builder.master("local").appName("chispa").getOrCreate() def describe_column_equality(): @@ -24,116 +21,98 @@ def test_removes_non_word_characters_short(): ("jo&&se", "jose"), ("**li**", "li"), ("#::luisa", "luisa"), - (None, None) + (None, None), ] - df = spark.createDataFrame(data, ["name", "expected_name"])\ - .withColumn("clean_name", remove_non_word_characters(F.col("name"))) + df = spark.createDataFrame(data, ["name", "expected_name"]).withColumn( + "clean_name", remove_non_word_characters(F.col("name")) + ) assert_column_equality(df, "clean_name", "expected_name") - def test_remove_non_word_characters_nice_error(): data = [ ("matt7", "matt"), ("bill&", "bill"), ("isabela*", "isabela"), - (None, None) + (None, None), ] - df = spark.createDataFrame(data, ["name", "expected_name"])\ - .withColumn("clean_name", remove_non_word_characters(F.col("name"))) + df = spark.createDataFrame(data, ["name", 
"expected_name"]).withColumn( + "clean_name", remove_non_word_characters(F.col("name")) + ) # assert_column_equality(df, "clean_name", "expected_name") with pytest.raises(ColumnsNotEqualError) as e_info: assert_column_equality(df, "clean_name", "expected_name") - def describe_dataframe_equality(): def test_remove_non_word_characters_long(): - source_data = [ - ("jo&&se",), - ("**li**",), - ("#::luisa",), - (None,) - ] + source_data = [("jo&&se",), ("**li**",), ("#::luisa",), (None,)] source_df = spark.createDataFrame(source_data, ["name"]) - actual_df = source_df.withColumn( - "clean_name", - remove_non_word_characters(F.col("name")) - ) + actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) expected_data = [ ("jo&&se", "jose"), ("**li**", "li"), ("#::luisa", "luisa"), - (None, None) + (None, None), ] expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) assert_df_equality(actual_df, expected_df) - def test_remove_non_word_characters_long_error(): - source_data = [ - ("matt7",), - ("bill&",), - ("isabela*",), - (None,) - ] + source_data = [("matt7",), ("bill&",), ("isabela*",), (None,)] source_df = spark.createDataFrame(source_data, ["name"]) - actual_df = source_df.withColumn( - "clean_name", - remove_non_word_characters(F.col("name")) - ) + actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) expected_data = [ ("matt7", "matt"), ("bill&", "bill"), ("isabela*", "isabela"), - (None, None) + (None, None), ] expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) # assert_df_equality(actual_df, expected_df) with pytest.raises(DataFramesNotEqualError) as e_info: assert_df_equality(actual_df, expected_df) - def ignore_row_order(): df1 = spark.createDataFrame([(1,), (2,), (3,)], ["some_num"]) df2 = spark.createDataFrame([(2,), (1,), (3,)], ["some_num"]) # assert_df_equality(df1, df2) assert_df_equality(df1, df2, ignore_row_order=True) - def 
ignore_column_order(): df1 = spark.createDataFrame([(1, 7), (2, 8), (3, 9)], ["num1", "num2"]) df2 = spark.createDataFrame([(7, 1), (8, 2), (9, 3)], ["num2", "num1"]) assert_df_equality(df1, df2, ignore_column_order=True) - def ignore_nullable_property(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) df1 = spark.createDataFrame([("juan", 7), ("bruna", 8)], s1) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), False)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), False), + ]) df2 = spark.createDataFrame([("juan", 7), ("bruna", 8)], s2) assert_df_equality(df1, df2, ignore_nullable=True) - def ignore_nullable_property_array(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("coords", ArrayType(DoubleType(), True), True),]) + StructField("name", StringType(), True), + StructField("coords", ArrayType(DoubleType(), True), True), + ]) df1 = spark.createDataFrame([("juan", [1.42, 3.5]), ("bruna", [2.76, 3.2])], s1) s2 = StructType([ - StructField("name", StringType(), True), - StructField("coords", ArrayType(DoubleType(), False), True),]) + StructField("name", StringType(), True), + StructField("coords", ArrayType(DoubleType(), False), True), + ]) df2 = spark.createDataFrame([("juan", [1.42, 3.5]), ("bruna", [2.76, 3.2])], s2) assert_df_equality(df1, df2, ignore_nullable=True) - def consider_nan_values_equal(): - data1 = [(float('nan'), "jose"), (2.0, "li")] + data1 = [(float("nan"), "jose"), (2.0, "li")] df1 = spark.createDataFrame(data1, ["num", "name"]) - data2 = [(float('nan'), "jose"), (2.0, "li")] + data2 = [(float("nan"), "jose"), (2.0, "li")] df2 = spark.createDataFrame(data2, ["num", "name"]) assert_df_equality(df1, df2, allow_nan_equality=True) @@ -152,7 +131,7 @@ def it_prints_underline_message(): ("li", 99), 
("rick", 66), ] - df2 = spark.createDataFrame(data, ["firstname", "age"]) + df2 = spark.createDataFrame(data, ["firstname", "age"]) with pytest.raises(DataFramesNotEqualError) as e_info: assert_df_equality(df1, df2, underline_cells=True) @@ -176,62 +155,30 @@ def it_shows_assert_basic_rows_equality(my_formats): with pytest.raises(DataFramesNotEqualError) as e_info: assert_basic_rows_equality(df1.collect(), df2.collect(), underline_cells=True) + def describe_assert_approx_column_equality(): def test_approx_col_equality_same(): - data = [ - (1.1, 1.1), - (2.2, 2.15), - (3.3, 3.37), - (None, None) - ] + data = [(1.1, 1.1), (2.2, 2.15), (3.3, 3.37), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) assert_approx_column_equality(df, "num1", "num2", 0.1) - def test_approx_col_equality_different(): - data = [ - (1.1, 1.1), - (2.2, 2.15), - (3.3, 5.0), - (None, None) - ] + data = [(1.1, 1.1), (2.2, 2.15), (3.3, 5.0), (None, None)] df = spark.createDataFrame(data, ["num1", "num2"]) with pytest.raises(ColumnsNotEqualError) as e_info: assert_approx_column_equality(df, "num1", "num2", 0.1) - def test_approx_df_equality_same(): - data1 = [ - (1.1, "a"), - (2.2, "b"), - (3.3, "c"), - (None, None) - ] + data1 = [(1.1, "a"), (2.2, "b"), (3.3, "c"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "letter"]) - data2 = [ - (1.05, "a"), - (2.13, "b"), - (3.3, "c"), - (None, None) - ] + data2 = [(1.05, "a"), (2.13, "b"), (3.3, "c"), (None, None)] df2 = spark.createDataFrame(data2, ["num", "letter"]) assert_approx_df_equality(df1, df2, 0.1) - def test_approx_df_equality_different(): - data1 = [ - (1.1, "a"), - (2.2, "b"), - (3.3, "c"), - (None, None) - ] + data1 = [(1.1, "a"), (2.2, "b"), (3.3, "c"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "letter"]) - data2 = [ - (1.1, "a"), - (5.0, "b"), - (3.3, "z"), - (None, None) - ] + data2 = [(1.1, "a"), (5.0, "b"), (3.3, "z"), (None, None)] df2 = spark.createDataFrame(data2, ["num", "letter"]) # 
assert_approx_df_equality(df1, df2, 0.1) with pytest.raises(DataFramesNotEqualError) as e_info: @@ -240,43 +187,25 @@ def test_approx_df_equality_different(): def describe_schema_mismatch_messages(): def test_schema_mismatch_message(): - data1 = [ - (1, "a"), - (2, "b"), - (3, "c"), - (None, None) - ] + data1 = [(1, "a"), (2, "b"), (3, "c"), (None, None)] df1 = spark.createDataFrame(data1, ["num", "letter"]) - data2 = [ - (1, 6), - (2, 7), - (3, 8), - (None, None) - ] + data2 = [(1, 6), (2, 7), (3, 8), (None, None)] df2 = spark.createDataFrame(data2, ["num", "num2"]) with pytest.raises(SchemasNotEqualError) as e_info: assert_df_equality(df1, df2) def test_remove_non_word_characters_long_error(my_chispa): - source_data = [ - ("matt7",), - ("bill&",), - ("isabela*",), - (None,) - ] + source_data = [("matt7",), ("bill&",), ("isabela*",), (None,)] source_df = spark.createDataFrame(source_data, ["name"]) - actual_df = source_df.withColumn( - "clean_name", - remove_non_word_characters(F.col("name")) - ) + actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) expected_data = [ ("matt7", "matt"), ("bill&", "bill"), ("isabela*", "isabela"), - (None, None) + (None, None), ] expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) # my_chispa.assert_df_equality(actual_df, expected_df) with pytest.raises(DataFramesNotEqualError) as e_info: - assert_df_equality(actual_df, expected_df) \ No newline at end of file + assert_df_equality(actual_df, expected_df) diff --git a/tests/test_row_comparer.py b/tests/test_row_comparer.py index 08b5da6..0bb52db 100644 --- a/tests/test_row_comparer.py +++ b/tests/test_row_comparer.py @@ -1,5 +1,3 @@ -import pytest - from .spark import * from chispa.row_comparer import * from pyspark.sql import Row @@ -10,18 +8,19 @@ def test_are_rows_equal(): assert are_rows_equal(Row("luisa", "laura"), Row("luisa", "laura")) == True assert are_rows_equal(Row(None, None), Row(None, None)) == True + def 
test_are_rows_equal_enhanced(): - assert are_rows_equal_enhanced(Row(n1 = "bob", n2 = "jose"), Row(n1 = "li", n2 = "li"), False) == False - assert are_rows_equal_enhanced(Row(n1 = "luisa", n2 = "laura"), Row(n1 = "luisa", n2 = "laura"), False) == True - assert are_rows_equal_enhanced(Row(n1 = None, n2 = None), Row(n1 = None, n2 = None), False) == True + assert are_rows_equal_enhanced(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), False) == False + assert are_rows_equal_enhanced(Row(n1="luisa", n2="laura"), Row(n1="luisa", n2="laura"), False) == True + assert are_rows_equal_enhanced(Row(n1=None, n2=None), Row(n1=None, n2=None), False) == True assert are_rows_equal_enhanced(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), True) == False - assert are_rows_equal_enhanced(Row(n1=float('nan'), n2="jose"), Row(n1=float('nan'), n2="jose"), True) == True - assert are_rows_equal_enhanced(Row(n1=float('nan'), n2="jose"), Row(n1="hi", n2="jose"), True) == False + assert are_rows_equal_enhanced(Row(n1=float("nan"), n2="jose"), Row(n1=float("nan"), n2="jose"), True) == True + assert are_rows_equal_enhanced(Row(n1=float("nan"), n2="jose"), Row(n1="hi", n2="jose"), True) == False def test_are_rows_approx_equal(): - assert are_rows_approx_equal(Row(num = 1.1, first_name = "li"), Row(num = 1.05, first_name = "li"), 0.1) == True - assert are_rows_approx_equal(Row(num = 5.0, first_name = "laura"), Row(num = 5.0, first_name = "laura"), 0.1) == True - assert are_rows_approx_equal(Row(num = 5.0, first_name = "laura"), Row(num = 5.9, first_name = "laura"), 0.1) == False - assert are_rows_approx_equal(Row(num = None, first_name = None), Row(num = None, first_name = None), 0.1) == True + assert are_rows_approx_equal(Row(num=1.1, first_name="li"), Row(num=1.05, first_name="li"), 0.1) == True + assert are_rows_approx_equal(Row(num=5.0, first_name="laura"), Row(num=5.0, first_name="laura"), 0.1) == True + assert are_rows_approx_equal(Row(num=5.0, first_name="laura"), Row(num=5.9, 
first_name="laura"), 0.1) == False + assert are_rows_approx_equal(Row(num=None, first_name=None), Row(num=None, first_name=None), 0.1) == True diff --git a/tests/test_rows_comparer.py b/tests/test_rows_comparer.py index 73ee4f7..ed54dc4 100644 --- a/tests/test_rows_comparer.py +++ b/tests/test_rows_comparer.py @@ -4,7 +4,6 @@ from chispa import * from chispa.rows_comparer import assert_basic_rows_equality from chispa import DataFramesNotEqualError -import math def describe_assert_basic_rows_equality(): @@ -30,4 +29,3 @@ def it_works_when_rows_are_the_same(): data2 = [(1, "jose"), (2, "li"), (3, "laura")] df2 = spark.createDataFrame(data2, ["name", "expected_name"]) assert_basic_rows_equality(df1.collect(), df2.collect()) - diff --git a/tests/test_schema_comparer.py b/tests/test_schema_comparer.py index f679fb0..576c778 100644 --- a/tests/test_schema_comparer.py +++ b/tests/test_schema_comparer.py @@ -7,33 +7,37 @@ def describe_assert_schema_equality(): def it_does_nothing_when_equal(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) assert_schema_equality(s1, s2) - def it_throws_when_column_names_differ(): s1 = StructType([ - StructField("HAHA", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("HAHA", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) with pytest.raises(SchemasNotEqualError) as e_info: assert_schema_equality(s1, s2) - def it_throws_when_schema_lengths_differ(): s1 = 
StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("fav_number", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("fav_number", IntegerType(), True), + ]) with pytest.raises(SchemasNotEqualError) as e_info: assert_schema_equality(s1, s2) @@ -41,53 +45,55 @@ def it_throws_when_schema_lengths_differ(): def describe_assert_schema_equality_ignore_nullable(): def it_has_good_error_messages_for_different_sized_schemas(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), False), - StructField("age", IntegerType(), True), - StructField("something", IntegerType(), True), - StructField("else", IntegerType(), True) + StructField("name", StringType(), False), + StructField("age", IntegerType(), True), + StructField("something", IntegerType(), True), + StructField("else", IntegerType(), True), ]) with pytest.raises(SchemasNotEqualError) as e_info: assert_schema_equality_ignore_nullable(s1, s2) - def it_does_nothing_when_equal(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) assert_schema_equality_ignore_nullable(s1, s2) - def it_does_nothing_when_only_nullable_flag_is_different(): s1 = StructType([ - 
StructField("name", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), False)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), False), + ]) assert_schema_equality_ignore_nullable(s1, s2) def describe_are_schemas_equal_ignore_nullable(): def it_returns_true_when_only_nullable_flag_is_different(): s1 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - StructField("coords", ArrayType(DoubleType(), True), True), + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("coords", ArrayType(DoubleType(), True), True), ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), False), - StructField("coords", ArrayType(DoubleType(), True), False), + StructField("name", StringType(), True), + StructField("age", IntegerType(), False), + StructField("coords", ArrayType(DoubleType(), True), False), ]) assert are_schemas_equal_ignore_nullable(s1, s2) == True - def it_returns_true_when_only_nullable_flag_is_different_within_array_element(): s1 = StructType([StructField("coords", ArrayType(DoubleType(), True), True)]) s2 = StructType([StructField("coords", ArrayType(DoubleType(), False), True)]) @@ -98,29 +104,31 @@ def it_returns_true_when_only_nullable_flag_is_different_within_nested_array_ele s2 = StructType([StructField("coords", ArrayType(ArrayType(DoubleType(), False), True), True)]) assert are_schemas_equal_ignore_nullable(s1, s2) == True - def it_returns_false_when_the_element_type_is_different_within_array(): s1 = StructType([StructField("coords", ArrayType(DoubleType(), True), True)]) s2 = StructType([StructField("coords", ArrayType(IntegerType(), True), True)]) assert are_schemas_equal_ignore_nullable(s1, 
s2) == False - def it_returns_false_when_column_names_differ(): s1 = StructType([ - StructField("blah", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("blah", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("name", StringType(), True), - StructField("age", IntegerType(), False)]) + StructField("name", StringType(), True), + StructField("age", IntegerType(), False), + ]) assert are_schemas_equal_ignore_nullable(s1, s2) == False def it_returns_false_when_columns_have_different_order(): s1 = StructType([ - StructField("blah", StringType(), True), - StructField("age", IntegerType(), True)]) + StructField("blah", StringType(), True), + StructField("age", IntegerType(), True), + ]) s2 = StructType([ - StructField("age", IntegerType(), False), - StructField("blah", StringType(), True)]) + StructField("age", IntegerType(), False), + StructField("blah", StringType(), True), + ]) assert are_schemas_equal_ignore_nullable(s1, s2) == False @@ -130,35 +138,31 @@ def it_returns_true_when_only_nullable_flag_is_different_within_array_element(): s2 = StructField("coords", ArrayType(DoubleType(), False), True) assert are_structfields_equal(s1, s2, True) == True - def it_returns_false_when_the_element_type_is_different_within_array(): s1 = StructField("coords", ArrayType(DoubleType(), True), True) s2 = StructField("coords", ArrayType(IntegerType(), True), True) assert are_structfields_equal(s1, s2, True) == False - def it_returns_true_when_the_element_type_is_same_within_struct(): s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) s2 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) assert are_structfields_equal(s1, s2, True) == True - def it_returns_false_when_the_element_type_is_different_within_struct(): s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) s2 = StructField("coords", 
StructType([StructField("hello", IntegerType(), True)]), True) assert are_structfields_equal(s1, s2, True) == False - def it_returns_false_when_the_element_name_is_different_within_struct(): s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) s2 = StructField("coords", StructType([StructField("world", DoubleType(), True)]), True) assert are_structfields_equal(s1, s2, True) == False - - + def it_returns_true_when_different_nullability_within_struct(): s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) s2 = StructField("coords", StructType([StructField("hello", DoubleType(), False)]), True) assert are_structfields_equal(s1, s2, True) == True + def it_returns_false_when_metadata_differs(): s1 = StructField("coords", StringType(), True, {"hi": "whatever"}) s2 = StructField("coords", StringType(), True, {"hi": "no"}) diff --git a/tests/test_structfield_comparer.py b/tests/test_structfield_comparer.py index df40aea..d73fc34 100644 --- a/tests/test_structfield_comparer.py +++ b/tests/test_structfield_comparer.py @@ -1,5 +1,3 @@ -import pytest - from chispa.structfield_comparer import are_structfields_equal from pyspark.sql.types import * @@ -53,4 +51,4 @@ def it_returns_false_when_nested_types_are_different_with_ignore_nullable_true() def it_returns_true_when_nested_types_have_different_nullability_with_ignore_null_true(): sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) sf2 = StructField("hi", StructType([StructField("world", IntegerType(), True)]), False) - assert are_structfields_equal(sf1, sf2, ignore_nullability=True) == True \ No newline at end of file + assert are_structfields_equal(sf1, sf2, ignore_nullability=True) == True diff --git a/tests/test_terminal_str_formatter.py b/tests/test_terminal_str_formatter.py index 7f73d4d..b387524 100644 --- a/tests/test_terminal_str_formatter.py +++ b/tests/test_terminal_str_formatter.py @@ -1,5 +1,3 @@ -import 
pytest - from chispa.terminal_str_formatter import format_string