Add pre-commit hooks and ruff #106

Merged
merged 5 commits on Jul 17, 2024
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
@@ -19,6 +19,9 @@ jobs:
       - name: Check lock file
         run: poetry lock --check
 
+      - name: Run pre-commit hooks
+        run: poetry run pre-commit run -a
+
   test:
     runs-on: ubuntu-latest
     strategy:
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,7 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: "v0.5.2"
+    hooks:
+      - id: ruff
+        args: [--exit-non-zero-on-fix]
+      - id: ruff-format
56 changes: 38 additions & 18 deletions chispa/__init__.py
@@ -9,42 +9,62 @@
 # We need to add PySpark, try use findspark, or failback to the "manually" find it
 try:
     import findspark
+
     findspark.init()
 except ImportError:
     try:
-        spark_home = os.environ['SPARK_HOME']
-        sys.path.append(os.path.join(spark_home, 'python'))
-        py4j_src_zip = glob(os.path.join(spark_home, 'python', 'lib', 'py4j-*-src.zip'))
+        spark_home = os.environ["SPARK_HOME"]
+        sys.path.append(os.path.join(spark_home, "python"))
+        py4j_src_zip = glob(os.path.join(spark_home, "python", "lib", "py4j-*-src.zip"))
         if len(py4j_src_zip) == 0:
-            raise ValueError('py4j source archive not found in %s'
-                             % os.path.join(spark_home, 'python', 'lib'))
+            raise ValueError("py4j source archive not found in %s" % os.path.join(spark_home, "python", "lib"))
         else:
             py4j_src_zip = sorted(py4j_src_zip)[::-1]
             sys.path.append(py4j_src_zip[0])
     except KeyError:
         print("Can't find Apache Spark. Please set environment variable SPARK_HOME to root of installation!")
         exit(-1)
 
-from .dataframe_comparer import DataFramesNotEqualError, assert_df_equality, assert_approx_df_equality
-from .column_comparer import ColumnsNotEqualError, assert_column_equality, assert_approx_column_equality
+from .dataframe_comparer import (
+    DataFramesNotEqualError,
+    assert_df_equality,
+    assert_approx_df_equality,
+)
+from .column_comparer import (
+    ColumnsNotEqualError,
+    assert_column_equality,
+    assert_approx_column_equality,
+)
 from .rows_comparer import assert_basic_rows_equality
 from chispa.default_formats import DefaultFormats
 
-class Chispa():
+
+class Chispa:
     def __init__(self, formats=DefaultFormats(), default_output=None):
         self.formats = formats
         self.default_outputs = default_output
 
-    def assert_df_equality(self, df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False,
-                           ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False):
+    def assert_df_equality(
+        self,
+        df1,
+        df2,
+        ignore_nullable=False,
+        transforms=None,
+        allow_nan_equality=False,
+        ignore_column_order=False,
+        ignore_row_order=False,
+        underline_cells=False,
+        ignore_metadata=False,
+    ):
         return assert_df_equality(
-            df1,
-            df2,
-            ignore_nullable,
-            transforms,
+            df1,
+            df2,
+            ignore_nullable,
+            transforms,
             allow_nan_equality,
-            ignore_column_order,
-            ignore_row_order,
-            underline_cells,
+            ignore_column_order,
+            ignore_row_order,
+            underline_cells,
             ignore_metadata,
-            self.formats)
+            self.formats,
+        )
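
For illustration only (outside the scope of this diff), a minimal usage sketch of the Chispa wrapper class shown above; the SparkSession, DataFrames, and column names are made up, and the wrapper simply forwards to the module-level assert_df_equality with the stored formats.

# Illustrative sketch of the Chispa wrapper (hypothetical data).
from pyspark.sql import SparkSession
from chispa import Chispa

spark = SparkSession.builder.getOrCreate()
df1 = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "letter"])
df2 = spark.createDataFrame([(2, "b"), (1, "a")], ["id", "letter"])

checker = Chispa()  # uses DefaultFormats() unless a custom formats object is passed
checker.assert_df_equality(df1, df2, ignore_row_order=True)  # raises DataFramesNotEqualError on mismatch
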
48 changes: 24 additions & 24 deletions chispa/bcolors.py
@@ -1,32 +1,32 @@
 class bcolors:
-    NC = '\033[0m'  # No Color, reset all
+    NC = "\033[0m"  # No Color, reset all
 
-    Bold = '\033[1m'
-    Underlined = '\033[4m'
-    Blink = '\033[5m'
-    Inverted = '\033[7m'
-    Hidden = '\033[8m'
+    Bold = "\033[1m"
+    Underlined = "\033[4m"
+    Blink = "\033[5m"
+    Inverted = "\033[7m"
+    Hidden = "\033[8m"
 
-    Black = '\033[30m'
-    Red = '\033[31m'
-    Green = '\033[32m'
-    Yellow = '\033[33m'
-    Blue = '\033[34m'
-    Purple = '\033[35m'
-    Cyan = '\033[36m'
-    LightGray = '\033[37m'
-    DarkGray = '\033[30m'
-    LightRed = '\033[31m'
-    LightGreen = '\033[32m'
-    LightYellow = '\033[93m'
-    LightBlue = '\033[34m'
-    LightPurple = '\033[35m'
-    LightCyan = '\033[36m'
-    White = '\033[97m'
+    Black = "\033[30m"
+    Red = "\033[31m"
+    Green = "\033[32m"
+    Yellow = "\033[33m"
+    Blue = "\033[34m"
+    Purple = "\033[35m"
+    Cyan = "\033[36m"
+    LightGray = "\033[37m"
+    DarkGray = "\033[30m"
+    LightRed = "\033[31m"
+    LightGreen = "\033[32m"
+    LightYellow = "\033[93m"
+    LightBlue = "\033[34m"
+    LightPurple = "\033[35m"
+    LightCyan = "\033[36m"
+    White = "\033[97m"
 
     # Style
-    Bold = '\033[1m'
-    Underline = '\033[4m'
+    Bold = "\033[1m"
+    Underline = "\033[4m"
 
 
 def blue(s: str) -> str:
10 changes: 5 additions & 5 deletions chispa/column_comparer.py
@@ -3,8 +3,9 @@
 
 
 class ColumnsNotEqualError(Exception):
-    """The columns are not equal"""
-    pass
+    """The columns are not equal"""
+
+    pass
 
 
 def assert_column_equality(df, col_name1, col_name2):
@@ -35,11 +36,11 @@ def assert_approx_column_equality(df, col_name1, col_name2, precision):
         first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed
         second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed
         # when one is None and the other isn't, they're not equal
-        if (elements[0] == None and elements[1] != None) or (elements[0] != None and elements[1] == None):
+        if (elements[0] is None and elements[1] is not None) or (elements[0] is not None and elements[1] is None):
             all_rows_equal = False
             t.add_row([str(elements[0]), str(elements[1])])
         # when both are None, they're equal
-        elif elements[0] == None and elements[1] == None:
+        elif elements[0] is None and elements[1] is None:
             t.add_row([first, second])
         # when the diff is less than the threshhold, they're approximately equal
         elif abs(elements[0] - elements[1]) < precision:
@@ -50,4 +51,3 @@ def assert_approx_column_equality(df, col_name1, col_name2, precision):
             t.add_row([str(elements[0]), str(elements[1])])
     if all_rows_equal == False:
         raise ColumnsNotEqualError("\n" + t.get_string())
-
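
For reference (illustrative only, not part of this PR), a small usage sketch of the comparer touched above; the DataFrame and column names are made up. Per the code above, values count as approximately equal when their absolute difference is strictly below the precision, and two Nones are treated as equal.

# Illustrative example of assert_approx_column_equality (hypothetical data).
from pyspark.sql import SparkSession
from chispa import assert_approx_column_equality

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1.1, 1.1), (2.2, 2.15), (None, None)],
    ["num1", "num2"],
)
# passes: 0.0 and 0.05 are both below the 0.1 precision, and None == None is allowed
assert_approx_column_equality(df, "num1", "num2", 0.1)
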
59 changes: 49 additions & 10 deletions chispa/dataframe_comparer.py
@@ -1,17 +1,31 @@
 from chispa.schema_comparer import assert_schema_equality
 from chispa.default_formats import DefaultFormats
-from chispa.rows_comparer import assert_basic_rows_equality, assert_generic_rows_equality
+from chispa.rows_comparer import (
+    assert_basic_rows_equality,
+    assert_generic_rows_equality,
+)
 from chispa.row_comparer import are_rows_equal_enhanced, are_rows_approx_equal
 from functools import reduce
 
 
 class DataFramesNotEqualError(Exception):
-    """The DataFrames are not equal"""
-    pass
+    """The DataFrames are not equal"""
+
+    pass
 
-def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False,
-                       ignore_column_order=False, ignore_row_order=False, underline_cells=False, ignore_metadata=False, formats=DefaultFormats()):
 
+def assert_df_equality(
+    df1,
+    df2,
+    ignore_nullable=False,
+    transforms=None,
+    allow_nan_equality=False,
+    ignore_column_order=False,
+    ignore_row_order=False,
+    underline_cells=False,
+    ignore_metadata=False,
+    formats=DefaultFormats(),
+):
     if transforms is None:
         transforms = []
     if ignore_column_order:
@@ -23,10 +37,20 @@ def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_n
     assert_schema_equality(df1.schema, df2.schema, ignore_nullable, ignore_metadata)
     if allow_nan_equality:
         assert_generic_rows_equality(
-            df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], underline_cells=underline_cells, formats=formats)
+            df1.collect(),
+            df2.collect(),
+            are_rows_equal_enhanced,
+            [True],
+            underline_cells=underline_cells,
+            formats=formats,
+        )
     else:
         assert_basic_rows_equality(
-            df1.collect(), df2.collect(), underline_cells=underline_cells, formats=formats)
+            df1.collect(),
+            df2.collect(),
+            underline_cells=underline_cells,
+            formats=formats,
+        )
 
 
 def are_dfs_equal(df1, df2):
@@ -37,8 +61,17 @@ def are_dfs_equal(df1, df2):
     return True
 
 
-def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transforms=None, allow_nan_equality=False,
-                              ignore_column_order=False, ignore_row_order=False, formats=DefaultFormats()):
+def assert_approx_df_equality(
+    df1,
+    df2,
+    precision,
+    ignore_nullable=False,
+    transforms=None,
+    allow_nan_equality=False,
+    ignore_column_order=False,
+    ignore_row_order=False,
+    formats=DefaultFormats(),
+):
     if transforms is None:
         transforms = []
     if ignore_column_order:
@@ -49,7 +82,13 @@ def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False, transf
     df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)
     assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
     if precision != 0:
-        assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_approx_equal, [precision, allow_nan_equality], formats)
+        assert_generic_rows_equality(
+            df1.collect(),
+            df2.collect(),
+            are_rows_approx_equal,
+            [precision, allow_nan_equality],
+            formats,
+        )
     elif allow_nan_equality:
         assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], formats)
     else:
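
For reference, a minimal sketch of calling the reformatted assert_approx_df_equality (illustrative only, not part of this diff; the DataFrames are hypothetical).

# Illustrative example of assert_approx_df_equality (hypothetical data).
from pyspark.sql import SparkSession
from chispa import assert_approx_df_equality

spark = SparkSession.builder.getOrCreate()
df1 = spark.createDataFrame([(1.0,), (2.005,)], ["num"])
df2 = spark.createDataFrame([(1.0,), (2.0,)], ["num"])

# passes: schemas match and every float differs by no more than the 0.01 precision
assert_approx_df_equality(df1, df2, 0.01)
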
1 change: 1 addition & 0 deletions chispa/default_formats.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 
+
 @dataclass
 class DefaultFormats:
     mismatched_rows = ["red"]
2 changes: 1 addition & 1 deletion chispa/number_helpers.py
@@ -13,4 +13,4 @@ def nan_safe_equality(x, y) -> bool:
 
 
 def nan_safe_approx_equality(x, y, precision) -> bool:
-    return (abs(x-y)<=precision) or (isnan(x) and isnan(y))
+    return (abs(x - y) <= precision) or (isnan(x) and isnan(y))
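
The reformatted helper above returns True when the values are within the given precision or when both are NaN. A tiny illustrative check (assuming chispa is installed so the module is importable):

# Illustrative check of nan_safe_approx_equality's behaviour.
from math import nan
from chispa.number_helpers import nan_safe_approx_equality

assert nan_safe_approx_equality(1.0, 1.05, 0.1)      # within tolerance
assert nan_safe_approx_equality(nan, nan, 0.1)       # both NaN counts as equal
assert not nan_safe_approx_equality(1.0, 2.0, 0.1)   # outside tolerance
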
7 changes: 3 additions & 4 deletions chispa/row_comparer.py
@@ -16,7 +16,7 @@ def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool:
     d2 = r2.asDict()
     if allow_nan_equality:
         for key in d1.keys() & d2.keys():
-            if not(nan_safe_equality(d1[key], d2[key])):
+            if not (nan_safe_equality(d1[key], d2[key])):
                 return False
         return True
     else:
@@ -33,13 +33,12 @@ def are_rows_approx_equal(r1: Row, r2: Row, precision: float, allow_nan_equality
     allEqual = True
     for key in d1.keys() & d2.keys():
         if isinstance(d1[key], float) and isinstance(d2[key], float):
-            if allow_nan_equality and not(nan_safe_approx_equality(d1[key], d2[key], precision)):
+            if allow_nan_equality and not (nan_safe_approx_equality(d1[key], d2[key], precision)):
                 allEqual = False
-            elif not(allow_nan_equality) and math.isnan(abs(d1[key] - d2[key])):
+            elif not (allow_nan_equality) and math.isnan(abs(d1[key] - d2[key])):
                 allEqual = False
             elif abs(d1[key] - d2[key]) > precision:
                 allEqual = False
         elif d1[key] != d2[key]:
             allEqual = False
     return allEqual
-
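
To make the two comparers above concrete, a short illustrative sketch (hypothetical rows, assuming pyspark is available): are_rows_approx_equal tolerates small float differences, while are_rows_equal_enhanced requires exact matches (optionally treating NaN == NaN as equal).

# Illustrative behaviour of the row comparers (hypothetical rows).
from pyspark.sql import Row
from chispa.row_comparer import are_rows_approx_equal, are_rows_equal_enhanced

r1 = Row(id=1, score=10.0)
r2 = Row(id=1, score=10.004)

assert are_rows_approx_equal(r1, r2, 0.01, False)   # 0.004 is within the 0.01 precision
assert not are_rows_equal_enhanced(r1, r2, True)    # exact comparison still fails
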
21 changes: 16 additions & 5 deletions chispa/rows_comparer.py
@@ -2,8 +2,6 @@
 from prettytable import PrettyTable
 from chispa.bcolors import *
 import chispa
-from pyspark.sql.types import Row
-from typing import List
 from chispa.terminal_str_formatter import format_string
 from chispa.default_formats import DefaultFormats
 
@@ -41,7 +39,14 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats=Defa
         raise chispa.DataFramesNotEqualError("\n" + t.get_string())
 
 
-def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fun_args, underline_cells=False, formats=DefaultFormats()):
+def assert_generic_rows_equality(
+    rows1,
+    rows2,
+    row_equality_fun,
+    row_equality_fun_args,
+    underline_cells=False,
+    formats=DefaultFormats(),
+):
     df1_rows = rows1
     df2_rows = rows2
     zipped = list(zip_longest(df1_rows, df2_rows))
@@ -51,12 +56,18 @@ def assert_generic_rows_equality(rows1, rows2, row_equality_fun, row_equality_fu
         # rows are not equal when one is None and the other isn't
         if (r1 is not None and r2 is None) or (r2 is not None and r1 is None):
             all_rows_equal = False
-            t.add_row([format_string(r1, formats.mismatched_rows), format_string(r2, formats.mismatched_rows)])
+            t.add_row([
+                format_string(r1, formats.mismatched_rows),
+                format_string(r2, formats.mismatched_rows),
+            ])
         # rows are equal
         elif row_equality_fun(r1, r2, *row_equality_fun_args):
             r1_string = ", ".join(map(lambda f: f"{f}={r1[f]}", r1.__fields__))
             r2_string = ", ".join(map(lambda f: f"{f}={r2[f]}", r2.__fields__))
-            t.add_row([format_string(r1_string, formats.matched_rows), format_string(r2_string, formats.matched_rows)])
+            t.add_row([
+                format_string(r1_string, formats.matched_rows),
+                format_string(r2_string, formats.matched_rows),
+            ])
         # otherwise, rows aren't equal
         else:
             r_zipped = list(zip_longest(r1.__fields__, r2.__fields__))
10 changes: 5 additions & 5 deletions chispa/schema_comparer.py
@@ -4,8 +4,9 @@
 
 
 class SchemasNotEqualError(Exception):
-    """The schemas are not equal"""
-    pass
+    """The schemas are not equal"""
+
+    pass
 
 
 def assert_schema_equality(s1, s2, ignore_nullable=False, ignore_metadata=False):
@@ -51,7 +52,6 @@ def assert_basic_schema_equality(s1, s2):
         raise SchemasNotEqualError("\n" + t.get_string())
 
 
-
 # deprecate this. ignore_nullable should be a flag.
 def assert_schema_equality_ignore_nullable(s1, s2):
     if not are_schemas_equal_ignore_nullable(s1, s2):
@@ -101,9 +101,9 @@ def are_datatypes_equal_ignore_nullable(dt1, dt2):
     """
     if dt1.typeName() == dt2.typeName():
         # Account for array types by inspecting elementType.
-        if dt1.typeName() == 'array':
+        if dt1.typeName() == "array":
             return are_datatypes_equal_ignore_nullable(dt1.elementType, dt2.elementType)
-        elif dt1.typeName() == 'struct':
+        elif dt1.typeName() == "struct":
             return are_schemas_equal_ignore_nullable(dt1, dt2)
         else:
             return True
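
A brief illustrative sketch of the array branch shown above (not part of this diff, assuming pyspark is installed): nullability of the array elements is never compared, so these two types are reported as equal.

# Illustrative check of are_datatypes_equal_ignore_nullable (array branch above).
from pyspark.sql.types import ArrayType, IntegerType
from chispa.schema_comparer import are_datatypes_equal_ignore_nullable

dt1 = ArrayType(IntegerType(), containsNull=True)
dt2 = ArrayType(IntegerType(), containsNull=False)

assert are_datatypes_equal_ignore_nullable(dt1, dt2)  # elementType matches; containsNull is ignored
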