Skip to content

Commit

Permalink
move dict parsing logic to format class
Browse files Browse the repository at this point in the history
move dict parsing logic to format class

inherit Enum from str

rename string formatter

simplify
  • Loading branch information
fpgmaas committed Jul 19, 2024
1 parent 6a3af47 commit 3a1d608
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 72 deletions.
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,6 @@ assert_df_equality(df1, df2, allow_nan_equality=True)

## Customize formatting

*Available in chispa 0.10+*.

You can specify custom formats for the printed error messages as follows:

```python
Expand All @@ -328,6 +326,21 @@ formats = FormattingConfig(
assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats)
```

or similarly:

```python
from chispa.formatting import FormattingConfig, Color, Style

formats = FormattingConfig(
mismatched_rows={"color": Color.LIGHT_YELLOW},
matched_rows={"color": Color.CYAN, "style": Style.BOLD},
mismatched_cells={"color": Color.PURPLE},
matched_cells={"color": Color.BLUE},
)

assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats)
```

You can also define these formats in `conftest.py` and inject them via a fixture:

```python
Expand Down
6 changes: 3 additions & 3 deletions chispa/default_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
@dataclass
class DefaultFormats:
"""
This class is now deprecated. For backwards compatibility, when it's used, it will try to match the
FormattingConfig class.
This class is now deprecated and should be removed in a future release, together with `convert_to_formatting_config`.
"""

mismatched_rows: list[str] = field(default_factory=lambda: ["red"])
Expand All @@ -27,7 +26,8 @@ def __post_init__(self):

def convert_to_formatting_config(instance: Any) -> FormattingConfig:
"""
Converts an instance with specified fields to a FormattingConfig instance.
Converts an instance of an arbitrary class with specified fields to a FormattingConfig instance.
This class is purely for backwards compatibility and should be removed in a future release.
"""

if type(instance) is not DefaultFormats:
Expand Down
4 changes: 2 additions & 2 deletions chispa/formatting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from chispa.formatting.formats import RESET, Color, Format, FormattingConfig, Style
from chispa.formatting.terminal_string_formatter import format_terminal_string
from chispa.formatting.terminal_string_formatter import format_string

__all__ = ("Style", "Color", "FormattingConfig", "Format", "format_terminal_string", "RESET")
__all__ = ("Style", "Color", "FormattingConfig", "Format", "format_string", "RESET")
81 changes: 43 additions & 38 deletions chispa/formatting/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
RESET = "\033[0m"


class Color(Enum):
class Color(str, Enum):
BLACK = "\033[30m"
RED = "\033[31m"
GREEN = "\033[32m"
Expand All @@ -26,7 +26,7 @@ class Color(Enum):
WHITE = "\033[97m"


class Style(Enum):
class Style(str, Enum):
BOLD = "\033[1m"
UNDERLINE = "\033[4m"
BLINK = "\033[5m"
Expand All @@ -39,6 +39,46 @@ class Format:
color: Color | None = None
style: list[Style] | None = None

@classmethod
def from_dict(cls, format_dict: dict) -> Format:
if not isinstance(format_dict, dict):
raise ValueError("Input must be a dictionary")

color = cls._get_color_enum(format_dict.get("color"))
style = format_dict.get("style")
if isinstance(style, str):
styles = [cls._get_style_enum(style)]
elif isinstance(style, list):
styles = [cls._get_style_enum(s) for s in style]
else:
styles = None

return cls(color=color, style=styles)

@staticmethod
def _get_color_enum(color: Color | str | None) -> Color | None:
if isinstance(color, Color):
return color
elif isinstance(color, str):
try:
return Color[color.upper()]
except KeyError:
valid_colors = [c.name.lower() for c in Color]
raise ValueError(f"Invalid color name: {color}. Valid color names are {valid_colors}")
return None

@staticmethod
def _get_style_enum(style: Style | str | None) -> Style | None:
if isinstance(style, Style):
return style
elif isinstance(style, str):
try:
return Style[style.upper()]
except KeyError:
valid_styles = [f.name.lower() for f in Style]
raise ValueError(f"Invalid style name: {style}. Valid style names are {valid_styles}")
return None


class FormattingConfig:
"""
Expand Down Expand Up @@ -83,40 +123,5 @@ def _parse_format(self, format: Format | dict) -> Format:
if isinstance(format, Format):
return format
elif isinstance(format, dict):
invalid_keys = set(format.keys()) - self.VALID_KEYS
if invalid_keys:
raise ValueError(f"Invalid keys in format dictionary: {invalid_keys}. Valid keys are {self.VALID_KEYS}")

color = self._get_color_enum(format.get("color"))
style = format.get("style")
if isinstance(style, str):
styles = [self._get_style_enum(style)]
elif isinstance(style, list):
styles = [self._get_style_enum(s) for s in style]
else:
styles = None

return Format(color=color, style=styles)
return Format.from_dict(format)
raise ValueError("Invalid format type. Must be Format or dict.")

def _get_color_enum(self, color: Color | str | None) -> Color | None:
if isinstance(color, Color):
return color
elif isinstance(color, str):
try:
return Color[color.upper()]
except KeyError:
valid_colors = [c.name.lower() for c in Color]
raise ValueError(f"Invalid color name: {color}. Valid color names are {valid_colors}")
return None

def _get_style_enum(self, style: Style | str | None) -> Style | None:
if isinstance(style, Style):
return style
elif isinstance(style, str):
try:
return Style[style.upper()]
except KeyError:
valid_styles = [f.name.lower() for f in Style]
raise ValueError(f"Invalid style name: {style}. Valid style names are {valid_styles}")
return None
6 changes: 3 additions & 3 deletions chispa/formatting/terminal_string_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from chispa.formatting.formats import RESET, Format


def format_terminal_string(input_string: str, format: Format) -> str:
def format_string(input_string: str, format: Format) -> str:
if not format.color and not format.style:
return input_string

Expand All @@ -12,10 +12,10 @@ def format_terminal_string(input_string: str, format: Format) -> str:

if format.style:
for style in format.style:
codes.append(style.value)
codes.append(style)

if format.color:
codes.append(format.color.value)
codes.append(format.color)

formatted_string = "".join(codes) + formatted_string + RESET
return formatted_string
30 changes: 15 additions & 15 deletions chispa/rows_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import chispa
from chispa.default_formats import convert_to_formatting_config
from chispa.formatting import FormattingConfig, format_terminal_string
from chispa.formatting import FormattingConfig, format_string


def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats: FormattingConfig | None = None):
Expand All @@ -22,10 +22,10 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats: For

for r1, r2 in zipped:
if r1 is None and r2 is not None:
t.add_row([None, format_terminal_string(str(r2), formats.mismatched_rows)])
t.add_row([None, format_string(str(r2), formats.mismatched_rows)])
all_rows_equal = False
elif r1 is not None and r2 is None:
t.add_row([format_terminal_string(str(r1), formats.mismatched_rows), None])
t.add_row([format_string(str(r1), formats.mismatched_rows), None])
all_rows_equal = False
else:
r_zipped = list(zip_longest(r1.__fields__, r2.__fields__))
Expand All @@ -34,11 +34,11 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats: For
for r1_field, r2_field in r_zipped:
if r1[r1_field] != r2[r2_field]:
all_rows_equal = False
r1_string.append(format_terminal_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells))
r2_string.append(format_terminal_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells))
r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells))
r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells))
else:
r1_string.append(format_terminal_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells))
r2_string.append(format_terminal_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells))
r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells))
r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells))
r1_res = ", ".join(r1_string)
r2_res = ", ".join(r2_string)

Expand Down Expand Up @@ -70,16 +70,16 @@ def assert_generic_rows_equality(
if (r1 is None) ^ (r2 is None):
all_rows_equal = False
t.add_row([
format_terminal_string(str(r1), formats.mismatched_rows),
format_terminal_string(str(r2), formats.mismatched_rows),
format_string(str(r1), formats.mismatched_rows),
format_string(str(r2), formats.mismatched_rows),
])
# rows are equal
elif row_equality_fun(r1, r2, *row_equality_fun_args):
r1_string = ", ".join(map(lambda f: f"{f}={r1[f]}", r1.__fields__))
r2_string = ", ".join(map(lambda f: f"{f}={r2[f]}", r2.__fields__))
t.add_row([
format_terminal_string(r1_string, formats.matched_rows),
format_terminal_string(r2_string, formats.matched_rows),
format_string(r1_string, formats.matched_rows),
format_string(r2_string, formats.matched_rows),
])
# otherwise, rows aren't equal
else:
Expand All @@ -89,11 +89,11 @@ def assert_generic_rows_equality(
for r1_field, r2_field in r_zipped:
if r1[r1_field] != r2[r2_field]:
all_rows_equal = False
r1_string.append(format_terminal_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells))
r2_string.append(format_terminal_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells))
r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells))
r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells))
else:
r1_string.append(format_terminal_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells))
r2_string.append(format_terminal_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells))
r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells))
r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells))
r1_res = ", ".join(r1_string)
r2_res = ", ".join(r2_string)

Expand Down
43 changes: 42 additions & 1 deletion tests/formatting/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re

from chispa.formatting.formats import Color, FormattingConfig, Style
from chispa.formatting.formats import Color, Format, FormattingConfig, Style


def test_default_mismatched_rows():
Expand Down Expand Up @@ -82,3 +82,44 @@ def test_invalid_key():
r"Invalid keys in format dictionary: \{'invalid_key'\}. Valid keys are \{('color', 'style'|'style', 'color')\}",
error_message,
)


def test_format_from_dict_valid():
format_dict = {"color": "blue", "style": ["bold", "underline"]}
format_instance = Format.from_dict(format_dict)
assert format_instance.color == Color.BLUE
assert format_instance.style == [Style.BOLD, Style.UNDERLINE]


def test_format_from_dict_invalid_color():
format_dict = {"color": "invalid_color", "style": ["bold"]}
try:
Format.from_dict(format_dict)
except ValueError as e:
assert (
str(e)
== "Invalid color name: invalid_color. Valid color names are ['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', 'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', 'light_cyan', 'white']"
)


def test_format_from_dict_invalid_style():
format_dict = {"color": "blue", "style": ["invalid_style"]}
try:
Format.from_dict(format_dict)
except ValueError as e:
assert (
str(e)
== "Invalid style name: invalid_style. Valid style names are ['bold', 'underline', 'blink', 'invert', 'hide']"
)


def test_format_from_dict_invalid_key():
format_dict = {"invalid_key": "value"}
try:
Format.from_dict(format_dict)
except ValueError as e:
error_message = str(e)
assert re.match(
r"Invalid keys in format dictionary: \{'invalid_key'\}. Valid keys are \{('color', 'style'|'style', 'color')\}",
error_message,
)
16 changes: 8 additions & 8 deletions tests/formatting/test_terminal_string_formatter.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
from __future__ import annotations

from chispa.formatting import RESET, format_terminal_string
from chispa.formatting import RESET, format_string
from chispa.formatting.formats import Color, Format, Style


def test_format_with_enum_inputs():
format = Format(color=Color.BLUE, style=[Style.BOLD, Style.UNDERLINE])
formatted_string = format_terminal_string("Hello, World!", format)
expected_string = f"{Style.BOLD.value}{Style.UNDERLINE.value}{Color.BLUE.value}Hello, World!{RESET}"
formatted_string = format_string("Hello, World!", format)
expected_string = f"{Style.BOLD}{Style.UNDERLINE}{Color.BLUE}Hello, World!{RESET}"
assert formatted_string == expected_string


def test_format_with_no_style():
format = Format(color=Color.GREEN, style=[])
formatted_string = format_terminal_string("Hello, World!", format)
expected_string = f"{Color.GREEN.value}Hello, World!{RESET}"
formatted_string = format_string("Hello, World!", format)
expected_string = f"{Color.GREEN}Hello, World!{RESET}"
assert formatted_string == expected_string


def test_format_with_no_color():
format = Format(color=None, style=[Style.BLINK])
formatted_string = format_terminal_string("Hello, World!", format)
expected_string = f"{Style.BLINK.value}Hello, World!{RESET}"
formatted_string = format_string("Hello, World!", format)
expected_string = f"{Style.BLINK}Hello, World!{RESET}"
assert formatted_string == expected_string


def test_format_with_no_color_or_style():
format = Format(color=None, style=[])
formatted_string = format_terminal_string("Hello, World!", format)
formatted_string = format_string("Hello, World!", format)
expected_string = "Hello, World!"
assert formatted_string == expected_string
File renamed without changes.

0 comments on commit 3a1d608

Please sign in to comment.