Skip to content

Commit

Permalink
feat: Introduce FormattingConfig and deprecate DefaultFormats (#127)
Browse files Browse the repository at this point in the history
* Introduce `FormattingConfig` and deprecate `DefaultFormats`
  • Loading branch information
fpgmaas authored Jul 20, 2024
1 parent 67a6079 commit 580631a
Show file tree
Hide file tree
Showing 33 changed files with 679 additions and 88 deletions.
46 changes: 33 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,31 +311,51 @@ assert_df_equality(df1, df2, allow_nan_equality=True)

## Customize formatting

*Available in chispa 0.10+*.

You can specify custom formats for the printed error messages as follows:

```python
@dataclass
class MyFormats:
mismatched_rows = ["light_yellow"]
matched_rows = ["cyan", "bold"]
mismatched_cells = ["purple"]
matched_cells = ["blue"]
from chispa import FormattingConfig

formats = FormattingConfig(
mismatched_rows={"color": "light_yellow"},
matched_rows={"color": "cyan", "style": "bold"},
mismatched_cells={"color": "purple"},
matched_cells={"color": "blue"},
)

assert_basic_rows_equality(df1.collect(), df2.collect(), formats=MyFormats())
assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats)
```

Or, equivalently, using the `Color` and `Style` enums instead of plain strings:

```python
from chispa import FormattingConfig, Color, Style

formats = FormattingConfig(
mismatched_rows={"color": Color.LIGHT_YELLOW},
matched_rows={"color": Color.CYAN, "style": Style.BOLD},
mismatched_cells={"color": Color.PURPLE},
matched_cells={"color": Color.BLUE},
)

assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats)
```

You can also define these formats in `conftest.py` and inject them via a fixture:

```python
@pytest.fixture()
def my_formats():
return MyFormats()
def chispa_formats():
return FormattingConfig(
mismatched_rows={"color": "light_yellow"},
matched_rows={"color": "cyan", "style": "bold"},
mismatched_cells={"color": "purple"},
matched_cells={"color": "blue"},
)

def test_shows_assert_basic_rows_equality(my_formats):
def test_shows_assert_basic_rows_equality(chispa_formats):
...
assert_basic_rows_equality(df1.collect(), df2.collect(), formats=my_formats)
assert_basic_rows_equality(df1.collect(), df2.collect(), formats=chispa_formats)
```

![custom_formats](https://github.com/MrPowers/chispa/blob/main/images/custom_formats.png)
Expand Down
19 changes: 16 additions & 3 deletions chispa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import sys
from glob import glob
Expand Down Expand Up @@ -28,6 +30,7 @@
exit(-1)

from chispa.default_formats import DefaultFormats
from chispa.formatting import Color, Format, FormattingConfig, Style

from .column_comparer import (
ColumnsNotEqualError,
Expand All @@ -43,8 +46,14 @@


class Chispa:
def __init__(self, formats=DefaultFormats(), default_output=None):
self.formats = formats
def __init__(self, formats: FormattingConfig | None = None, default_output=None):
if not formats:
self.formats = FormattingConfig()
elif isinstance(formats, FormattingConfig):
self.formats = formats
else:
self.formats = FormattingConfig._from_arbitrary_dataclass(formats)

self.default_outputs = default_output

def assert_df_equality(
Expand Down Expand Up @@ -81,6 +90,10 @@ def assert_df_equality(
"assert_column_equality",
"assert_approx_column_equality",
"assert_basic_rows_equality",
"DefaultFormats",
"Style",
"Color",
"FormattingConfig",
"Format",
"Chispa",
"DefaultFormats",
)
3 changes: 3 additions & 0 deletions chispa/bcolors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


class bcolors:
NC = "\033[0m" # No Color, reset all

Expand Down
2 changes: 2 additions & 0 deletions chispa/column_comparer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from prettytable import PrettyTable

from chispa.bcolors import bcolors
Expand Down
18 changes: 15 additions & 3 deletions chispa/dataframe_comparer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from functools import reduce

from chispa.default_formats import DefaultFormats
from chispa.formatting import FormattingConfig
from chispa.row_comparer import are_rows_approx_equal, are_rows_equal_enhanced
from chispa.rows_comparer import (
assert_basic_rows_equality,
Expand All @@ -25,8 +27,13 @@ def assert_df_equality(
ignore_row_order=False,
underline_cells=False,
ignore_metadata=False,
formats=DefaultFormats(),
formats: FormattingConfig | None = None,
):
if not formats:
formats = FormattingConfig()
elif not isinstance(formats, FormattingConfig):
formats = FormattingConfig._from_arbitrary_dataclass(formats)

if transforms is None:
transforms = []
if ignore_column_order:
Expand Down Expand Up @@ -71,8 +78,13 @@ def assert_approx_df_equality(
allow_nan_equality=False,
ignore_column_order=False,
ignore_row_order=False,
formats=DefaultFormats(),
formats: FormattingConfig | None = None,
):
if not formats:
formats = FormattingConfig()
elif not isinstance(formats, FormattingConfig):
formats = FormattingConfig._from_arbitrary_dataclass(formats)

if transforms is None:
transforms = []
if ignore_column_order:
Expand Down
22 changes: 17 additions & 5 deletions chispa/default_formats.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from dataclasses import dataclass
from __future__ import annotations

import warnings
from dataclasses import dataclass, field


@dataclass
class DefaultFormats:
    """
    Legacy container for the default output formats, as lists of color/style names.

    This class is deprecated and is expected to be removed in a future release;
    use `chispa.formatting.FormattingConfig` instead. Creating an instance emits
    a DeprecationWarning.
    """

    # default_factory gives every instance its own list, so mutating one
    # instance's formats cannot leak into another.
    mismatched_rows: list[str] = field(default_factory=lambda: ["red"])
    matched_rows: list[str] = field(default_factory=lambda: ["blue"])
    mismatched_cells: list[str] = field(default_factory=lambda: ["red", "underline"])
    matched_cells: list[str] = field(default_factory=lambda: ["blue"])

    def __post_init__(self) -> None:
        # stacklevel=3 attributes the warning to the code that instantiated the
        # class: level 1 is this method, level 2 is the dataclass-generated
        # __init__, level 3 is the caller. Without it the warning points at
        # this file, which hides where the deprecated usage actually lives.
        warnings.warn(
            "DefaultFormats is deprecated. Use `chispa.formatting.FormattingConfig` instead.",
            DeprecationWarning,
            stacklevel=3,
        )
7 changes: 7 additions & 0 deletions chispa/formatting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from __future__ import annotations

from chispa.formatting.format_string import format_string
from chispa.formatting.formats import RESET, Color, Format, Style
from chispa.formatting.formatting_config import FormattingConfig

__all__ = ("Style", "Color", "FormattingConfig", "Format", "format_string", "RESET")
21 changes: 21 additions & 0 deletions chispa/formatting/format_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from chispa.formatting.formats import RESET, Format


def format_string(input_string: str, format: Format) -> str:
    """
    Wrap a string in the escape codes described by a Format.

    Args:
        input_string (str): The text to decorate.
        format (Format): The color and/or styles to apply.

    Returns:
        str: `input_string` unchanged when the format has neither a color nor
        any style; otherwise the style codes, then the color code, then the
        text, then RESET.
    """
    if not (format.color or format.style):
        return input_string

    # Style codes come first, the color code last, matching the emitted order.
    prefix_parts = [applied_style.value for applied_style in (format.style or ())]
    if format.color:
        prefix_parts.append(format.color.value)

    return "".join(prefix_parts) + input_string + RESET
136 changes: 136 additions & 0 deletions chispa/formatting/formats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

RESET = "\033[0m"


class Color(str, Enum):
    """
    Terminal text colors.

    The value of each member is the ANSI escape code that selects that color.
    Members are also `str` instances, so they compare equal to their escape code.
    """

    # Codes 30-37.
    BLACK = "\033[30m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    PURPLE = "\033[35m"
    CYAN = "\033[36m"
    LIGHT_GRAY = "\033[37m"

    # Codes 90-97.
    DARK_GRAY = "\033[90m"
    LIGHT_RED = "\033[91m"
    LIGHT_GREEN = "\033[92m"
    LIGHT_YELLOW = "\033[93m"
    LIGHT_BLUE = "\033[94m"
    LIGHT_PURPLE = "\033[95m"
    LIGHT_CYAN = "\033[96m"
    WHITE = "\033[97m"


class Style(str, Enum):
    """
    Terminal text styles.

    The value of each member is the ANSI escape code that enables that style.
    Members are also `str` instances, so they compare equal to their escape code.
    """

    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    BLINK = "\033[5m"
    INVERT = "\033[7m"
    HIDE = "\033[8m"


@dataclass
class Format:
    """
    Data class to represent text formatting with color and style.

    Attributes:
        color (Color | None): The color for the text, or None for no color.
        style (list[Style] | None): A list of styles for the text, or None for no styles.
    """

    color: Color | None = None
    style: list[Style] | None = None

    @classmethod
    def from_dict(cls, format_dict: dict) -> Format:
        """
        Create a Format instance from a dictionary.

        Args:
            format_dict (dict): A dictionary with keys 'color' and/or 'style'.
                'color' may be a Color member or a color name (case-insensitive).
                'style' may be a single Style member or style name, or a list or
                tuple of them. Values of any other type are ignored.

        Returns:
            Format: The format described by the dictionary.

        Raises:
            ValueError: If the input is not a dictionary, contains keys other than
                'color' and 'style', or names an unknown color or style.
        """
        if not isinstance(format_dict, dict):
            raise ValueError("Input must be a dictionary")

        valid_keys = {"color", "style"}
        invalid_keys = set(format_dict) - valid_keys
        if invalid_keys:
            raise ValueError(f"Invalid keys in format dictionary: {invalid_keys}. Valid keys are {valid_keys}")

        color = cls._get_color_enum(format_dict.get("color"))
        style = format_dict.get("style")
        if isinstance(style, str):
            # Also handles a single Style member, because Style subclasses str.
            styles = [cls._get_style_enum(style)]
        elif isinstance(style, (list, tuple)):
            # Tuples are accepted alongside lists so that ("bold", "underline")
            # is not silently dropped.
            styles = [cls._get_style_enum(s) for s in style]
        else:
            styles = None

        return cls(color=color, style=styles)

    @classmethod
    def from_list(cls, values: list[str]) -> Format:
        """
        Create a Format instance from a list of strings.

        Names are matched case-insensitively, consistent with `from_dict`. If
        several colors are given, the last one wins; styles accumulate in order.

        Args:
            values (list[str]): A list of strings representing colors and styles.

        Returns:
            Format: The format described by the list; `style` is None when the
            list names no style.

        Raises:
            ValueError: If any element is not a string or is not a known color
                or style name.
        """
        if not all(isinstance(value, str) for value in values):
            raise ValueError("All elements in the list must be strings")

        color = None
        styles = []

        for value in values:
            member_name = value.upper()
            if member_name in Color.__members__:
                color = Color[member_name]
            elif member_name in Style.__members__:
                styles.append(Style[member_name])
            else:
                valid_colors = [c.name.lower() for c in Color]
                valid_styles = [s.name.lower() for s in Style]
                raise ValueError(
                    f"Invalid value: {value}. Valid values are colors: {valid_colors} and styles: {valid_styles}"
                )

        return cls(color=color, style=styles if styles else None)

    @staticmethod
    def _get_color_enum(color: Color | str | None) -> Color | None:
        """
        Coerce a Color member or color name into a Color.

        Returns the member itself for a Color, the member with that name
        (case-insensitive) for a str, and None for anything else.

        Raises:
            ValueError: If a string does not name a Color member.
        """
        if isinstance(color, Color):
            return color
        elif isinstance(color, str):
            try:
                return Color[color.upper()]
            except KeyError:
                valid_colors = [c.name.lower() for c in Color]
                # `from None`: the KeyError from the enum lookup is an
                # implementation detail and only adds noise to the traceback.
                raise ValueError(f"Invalid color name: {color}. Valid color names are {valid_colors}") from None
        return None

    @staticmethod
    def _get_style_enum(style: Style | str | None) -> Style | None:
        """
        Coerce a Style member or style name into a Style.

        Returns the member itself for a Style, the member with that name
        (case-insensitive) for a str, and None for anything else.

        Raises:
            ValueError: If a string does not name a Style member.
        """
        if isinstance(style, Style):
            return style
        elif isinstance(style, str):
            try:
                return Style[style.upper()]
            except KeyError:
                valid_styles = [f.name.lower() for f in Style]
                # `from None`: see _get_color_enum.
                raise ValueError(f"Invalid style name: {style}. Valid style names are {valid_styles}") from None
        return None
Loading

0 comments on commit 580631a

Please sign in to comment.