From 3a1d608a82772ae42208e03bfa93accd2d74e402 Mon Sep 17 00:00:00 2001 From: Florian Maas Date: Fri, 19 Jul 2024 06:37:54 +0200 Subject: [PATCH] move dict parsing logic to format class move dict parsing logic to format class inherit Enum from str rename string formatter simplify --- README.md | 17 +++- chispa/default_formats.py | 6 +- chispa/formatting/__init__.py | 4 +- chispa/formatting/formats.py | 81 ++++++++++--------- .../formatting/terminal_string_formatter.py | 6 +- chispa/rows_comparer.py | 30 +++---- tests/formatting/test_formats.py | 43 +++++++++- .../test_terminal_string_formatter.py | 16 ++-- ..._default_formats.py => test_deprecated.py} | 0 9 files changed, 131 insertions(+), 72 deletions(-) rename tests/{test_default_formats.py => test_deprecated.py} (100%) diff --git a/README.md b/README.md index 260193b..7fa8a63 100644 --- a/README.md +++ b/README.md @@ -311,8 +311,6 @@ assert_df_equality(df1, df2, allow_nan_equality=True) ## Customize formatting -*Available in chispa 0.10+*. - You can specify custom formats for the printed error messages as follows: ```python @@ -328,6 +326,21 @@ formats = FormattingConfig( assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) ``` +or similarly: + +```python +from chispa.formatting import FormattingConfig, Color, Style + +formats = FormattingConfig( + mismatched_rows={"color": Color.LIGHT_YELLOW}, + matched_rows={"color": Color.CYAN, "style": Style.BOLD}, + mismatched_cells={"color": Color.PURPLE}, + matched_cells={"color": Color.BLUE}, + ) + +assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) +``` + You can also define these formats in `conftest.py` and inject them via a fixture: ```python diff --git a/chispa/default_formats.py b/chispa/default_formats.py index 1cb9e04..5dde7bf 100644 --- a/chispa/default_formats.py +++ b/chispa/default_formats.py @@ -10,8 +10,7 @@ @dataclass class DefaultFormats: """ - This class is now deprecated. 
For backwards compatibility, when it's used, it will try to match the - FormattingConfig class. + This class is now deprecated and should be removed in a future release, together with `convert_to_formatting_config`. """ mismatched_rows: list[str] = field(default_factory=lambda: ["red"]) @@ -27,7 +26,8 @@ def __post_init__(self): def convert_to_formatting_config(instance: Any) -> FormattingConfig: """ - Converts an instance with specified fields to a FormattingConfig instance. + Converts an instance of an arbitrary class with specified fields to a FormattingConfig instance. + This function is purely for backwards compatibility and should be removed in a future release. """ if type(instance) is not DefaultFormats: diff --git a/chispa/formatting/__init__.py b/chispa/formatting/__init__.py index f5371d2..645e52a 100644 --- a/chispa/formatting/__init__.py +++ b/chispa/formatting/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations from chispa.formatting.formats import RESET, Color, Format, FormattingConfig, Style -from chispa.formatting.terminal_string_formatter import format_terminal_string +from chispa.formatting.terminal_string_formatter import format_string -__all__ = ("Style", "Color", "FormattingConfig", "Format", "format_terminal_string", "RESET") +__all__ = ("Style", "Color", "FormattingConfig", "Format", "format_string", "RESET") diff --git a/chispa/formatting/formats.py b/chispa/formatting/formats.py index 28ae801..8ab2b17 100644 --- a/chispa/formatting/formats.py +++ b/chispa/formatting/formats.py @@ -7,7 +7,7 @@ RESET = "\033[0m" -class Color(Enum): +class Color(str, Enum): BLACK = "\033[30m" RED = "\033[31m" GREEN = "\033[32m" @@ -26,7 +26,7 @@ class Color(Enum): WHITE = "\033[97m" -class Style(Enum): +class Style(str, Enum): BOLD = "\033[1m" UNDERLINE = "\033[4m" BLINK = "\033[5m" @@ -39,6 +39,46 @@ class Format: color: Color | None = None style: list[Style] | None = None + @classmethod + def from_dict(cls, format_dict: dict) -> Format: + if not 
isinstance(format_dict, dict): + raise ValueError("Input must be a dictionary") + + color = cls._get_color_enum(format_dict.get("color")) + style = format_dict.get("style") + if isinstance(style, str): + styles = [cls._get_style_enum(style)] + elif isinstance(style, list): + styles = [cls._get_style_enum(s) for s in style] + else: + styles = None + + return cls(color=color, style=styles) + + @staticmethod + def _get_color_enum(color: Color | str | None) -> Color | None: + if isinstance(color, Color): + return color + elif isinstance(color, str): + try: + return Color[color.upper()] + except KeyError: + valid_colors = [c.name.lower() for c in Color] + raise ValueError(f"Invalid color name: {color}. Valid color names are {valid_colors}") + return None + + @staticmethod + def _get_style_enum(style: Style | str | None) -> Style | None: + if isinstance(style, Style): + return style + elif isinstance(style, str): + try: + return Style[style.upper()] + except KeyError: + valid_styles = [f.name.lower() for f in Style] + raise ValueError(f"Invalid style name: {style}. Valid style names are {valid_styles}") + return None + class FormattingConfig: """ @@ -83,40 +123,5 @@ def _parse_format(self, format: Format | dict) -> Format: if isinstance(format, Format): return format elif isinstance(format, dict): - invalid_keys = set(format.keys()) - self.VALID_KEYS - if invalid_keys: - raise ValueError(f"Invalid keys in format dictionary: {invalid_keys}. Valid keys are {self.VALID_KEYS}") - - color = self._get_color_enum(format.get("color")) - style = format.get("style") - if isinstance(style, str): - styles = [self._get_style_enum(style)] - elif isinstance(style, list): - styles = [self._get_style_enum(s) for s in style] - else: - styles = None - - return Format(color=color, style=styles) + return Format.from_dict(format) raise ValueError("Invalid format type. 
Must be Format or dict.") - - def _get_color_enum(self, color: Color | str | None) -> Color | None: - if isinstance(color, Color): - return color - elif isinstance(color, str): - try: - return Color[color.upper()] - except KeyError: - valid_colors = [c.name.lower() for c in Color] - raise ValueError(f"Invalid color name: {color}. Valid color names are {valid_colors}") - return None - - def _get_style_enum(self, style: Style | str | None) -> Style | None: - if isinstance(style, Style): - return style - elif isinstance(style, str): - try: - return Style[style.upper()] - except KeyError: - valid_styles = [f.name.lower() for f in Style] - raise ValueError(f"Invalid style name: {style}. Valid style names are {valid_styles}") - return None diff --git a/chispa/formatting/terminal_string_formatter.py b/chispa/formatting/terminal_string_formatter.py index d3402fc..a7495f8 100644 --- a/chispa/formatting/terminal_string_formatter.py +++ b/chispa/formatting/terminal_string_formatter.py @@ -3,7 +3,7 @@ from chispa.formatting.formats import RESET, Format -def format_terminal_string(input_string: str, format: Format) -> str: +def format_string(input_string: str, format: Format) -> str: if not format.color and not format.style: return input_string @@ -12,10 +12,10 @@ def format_terminal_string(input_string: str, format: Format) -> str: if format.style: for style in format.style: - codes.append(style.value) + codes.append(style) if format.color: - codes.append(format.color.value) + codes.append(format.color) formatted_string = "".join(codes) + formatted_string + RESET return formatted_string diff --git a/chispa/rows_comparer.py b/chispa/rows_comparer.py index 583d863..ff2f01c 100644 --- a/chispa/rows_comparer.py +++ b/chispa/rows_comparer.py @@ -6,7 +6,7 @@ import chispa from chispa.default_formats import convert_to_formatting_config -from chispa.formatting import FormattingConfig, format_terminal_string +from chispa.formatting import FormattingConfig, format_string def 
assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats: FormattingConfig | None = None): @@ -22,10 +22,10 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats: For for r1, r2 in zipped: if r1 is None and r2 is not None: - t.add_row([None, format_terminal_string(str(r2), formats.mismatched_rows)]) + t.add_row([None, format_string(str(r2), formats.mismatched_rows)]) all_rows_equal = False elif r1 is not None and r2 is None: - t.add_row([format_terminal_string(str(r1), formats.mismatched_rows), None]) + t.add_row([format_string(str(r1), formats.mismatched_rows), None]) all_rows_equal = False else: r_zipped = list(zip_longest(r1.__fields__, r2.__fields__)) @@ -34,11 +34,11 @@ def assert_basic_rows_equality(rows1, rows2, underline_cells=False, formats: For for r1_field, r2_field in r_zipped: if r1[r1_field] != r2[r2_field]: all_rows_equal = False - r1_string.append(format_terminal_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells)) - r2_string.append(format_terminal_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells)) + r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells)) + r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells)) else: - r1_string.append(format_terminal_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells)) - r2_string.append(format_terminal_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells)) + r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells)) + r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells)) r1_res = ", ".join(r1_string) r2_res = ", ".join(r2_string) @@ -70,16 +70,16 @@ def assert_generic_rows_equality( if (r1 is None) ^ (r2 is None): all_rows_equal = False t.add_row([ - format_terminal_string(str(r1), formats.mismatched_rows), - format_terminal_string(str(r2), formats.mismatched_rows), + format_string(str(r1), formats.mismatched_rows), 
+ format_string(str(r2), formats.mismatched_rows), ]) # rows are equal elif row_equality_fun(r1, r2, *row_equality_fun_args): r1_string = ", ".join(map(lambda f: f"{f}={r1[f]}", r1.__fields__)) r2_string = ", ".join(map(lambda f: f"{f}={r2[f]}", r2.__fields__)) t.add_row([ - format_terminal_string(r1_string, formats.matched_rows), - format_terminal_string(r2_string, formats.matched_rows), + format_string(r1_string, formats.matched_rows), + format_string(r2_string, formats.matched_rows), ]) # otherwise, rows aren't equal else: @@ -89,11 +89,11 @@ def assert_generic_rows_equality( for r1_field, r2_field in r_zipped: if r1[r1_field] != r2[r2_field]: all_rows_equal = False - r1_string.append(format_terminal_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells)) - r2_string.append(format_terminal_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells)) + r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells)) + r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells)) else: - r1_string.append(format_terminal_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells)) - r2_string.append(format_terminal_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells)) + r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells)) + r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells)) r1_res = ", ".join(r1_string) r2_res = ", ".join(r2_string) diff --git a/tests/formatting/test_formats.py b/tests/formatting/test_formats.py index ddab92c..bf25132 100644 --- a/tests/formatting/test_formats.py +++ b/tests/formatting/test_formats.py @@ -2,7 +2,7 @@ import re -from chispa.formatting.formats import Color, FormattingConfig, Style +from chispa.formatting.formats import Color, Format, FormattingConfig, Style def test_default_mismatched_rows(): @@ -82,3 +82,44 @@ def test_invalid_key(): r"Invalid keys in format dictionary: \{'invalid_key'\}. 
Valid keys are \{('color', 'style'|'style', 'color')\}", error_message, ) + + +def test_format_from_dict_valid(): + format_dict = {"color": "blue", "style": ["bold", "underline"]} + format_instance = Format.from_dict(format_dict) + assert format_instance.color == Color.BLUE + assert format_instance.style == [Style.BOLD, Style.UNDERLINE] + + +def test_format_from_dict_invalid_color(): + format_dict = {"color": "invalid_color", "style": ["bold"]} + try: + Format.from_dict(format_dict) + except ValueError as e: + assert ( + str(e) + == "Invalid color name: invalid_color. Valid color names are ['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', 'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', 'light_cyan', 'white']" + ) + + +def test_format_from_dict_invalid_style(): + format_dict = {"color": "blue", "style": ["invalid_style"]} + try: + Format.from_dict(format_dict) + except ValueError as e: + assert ( + str(e) + == "Invalid style name: invalid_style. Valid style names are ['bold', 'underline', 'blink', 'invert', 'hide']" + ) + + +def test_format_from_dict_invalid_key(): + format_dict = {"invalid_key": "value"} + try: + Format.from_dict(format_dict) + except ValueError as e: + error_message = str(e) + assert re.match( + r"Invalid keys in format dictionary: \{'invalid_key'\}. 
Valid keys are \{('color', 'style'|'style', 'color')\}", + error_message, + ) diff --git a/tests/formatting/test_terminal_string_formatter.py b/tests/formatting/test_terminal_string_formatter.py index 495caab..8b6b721 100644 --- a/tests/formatting/test_terminal_string_formatter.py +++ b/tests/formatting/test_terminal_string_formatter.py @@ -1,32 +1,32 @@ from __future__ import annotations -from chispa.formatting import RESET, format_terminal_string +from chispa.formatting import RESET, format_string from chispa.formatting.formats import Color, Format, Style def test_format_with_enum_inputs(): format = Format(color=Color.BLUE, style=[Style.BOLD, Style.UNDERLINE]) - formatted_string = format_terminal_string("Hello, World!", format) - expected_string = f"{Style.BOLD.value}{Style.UNDERLINE.value}{Color.BLUE.value}Hello, World!{RESET}" + formatted_string = format_string("Hello, World!", format) + expected_string = f"{Style.BOLD}{Style.UNDERLINE}{Color.BLUE}Hello, World!{RESET}" assert formatted_string == expected_string def test_format_with_no_style(): format = Format(color=Color.GREEN, style=[]) - formatted_string = format_terminal_string("Hello, World!", format) - expected_string = f"{Color.GREEN.value}Hello, World!{RESET}" + formatted_string = format_string("Hello, World!", format) + expected_string = f"{Color.GREEN}Hello, World!{RESET}" assert formatted_string == expected_string def test_format_with_no_color(): format = Format(color=None, style=[Style.BLINK]) - formatted_string = format_terminal_string("Hello, World!", format) - expected_string = f"{Style.BLINK.value}Hello, World!{RESET}" + formatted_string = format_string("Hello, World!", format) + expected_string = f"{Style.BLINK}Hello, World!{RESET}" assert formatted_string == expected_string def test_format_with_no_color_or_style(): format = Format(color=None, style=[]) - formatted_string = format_terminal_string("Hello, World!", format) + formatted_string = format_string("Hello, World!", format) expected_string 
= "Hello, World!" assert formatted_string == expected_string diff --git a/tests/test_default_formats.py b/tests/test_deprecated.py similarity index 100% rename from tests/test_default_formats.py rename to tests/test_deprecated.py