From f2df98d2b2ab580bac1d707111e654c9ae152f56 Mon Sep 17 00:00:00 2001 From: Claas Date: Mon, 11 Nov 2024 12:13:14 +0100 Subject: [PATCH] improve typing (work in progress) --- src/dictIO/dict.py | 104 +++++++++++++++++++------------------- src/dictIO/dict_reader.py | 23 +++++---- src/dictIO/formatter.py | 48 +++++++++++------- src/dictIO/parser.py | 14 ++--- src/dictIO/utils/dict.py | 16 +++--- tests/test_cpp_dict.py | 10 ++-- tests/test_dict.py | 38 +++++++------- tests/test_formatter.py | 2 +- tests/test_parser.py | 4 +- 9 files changed, 140 insertions(+), 119 deletions(-) diff --git a/src/dictIO/dict.py b/src/dictIO/dict.py index f3e5eb95..6a97e3c5 100644 --- a/src/dictIO/dict.py +++ b/src/dictIO/dict.py @@ -13,12 +13,13 @@ from copy import copy from pathlib import Path from typing import ( + Any, TypeVar, cast, overload, ) -from dictIO.types import K, TGlobalKey, TKey, TValue, V +from dictIO.types import K, TKey, TValue, V from dictIO.utils.counter import BorgCounter from dictIO.utils.dict import ( find_global_key, @@ -160,7 +161,7 @@ def fromkeys( cls, iterable: Iterable[_K], value: None = None, - ) -> SDict[_K, TValue | None]: + ) -> SDict[_K, Any | None]: pass @overload @@ -177,7 +178,7 @@ def fromkeys( cls, iterable: Iterable[_K], value: _V | None = None, - ) -> SDict[_K, _V] | SDict[_K, TValue | None]: + ) -> SDict[_K, _V] | SDict[_K, Any | None]: """Create a new SDict instance from the keys of an iterable. Parameters @@ -189,7 +190,7 @@ def fromkeys( Returns ------- - SDict[_K, _V] | SDict[_K, TValue | None] + SDict[_K, _V] | SDict[_K, Any | None] The created SDict instance. """ new_dict: SDict[_K, _V] = cast(SDict[_K, _V], cls()) @@ -245,7 +246,7 @@ def load( from dictIO.dict_reader import DictReader - loaded_dict: SDict[TKey, TValue] = DictReader.read( + loaded_dict: SDict[K, V] = DictReader.read( source_file=source_file, ) self.reset() @@ -254,7 +255,7 @@ def load( # Maybe this method needs to be refactored to a factory function, returning a new `SDict` instance # with the actual types of `loaded_dict`. # CLAROS, 2024-11-06 - self.update(cast(SDict[K, V], loaded_dict)) + self.update(loaded_dict) self._set_source_file(source_file) return self @@ -299,10 +300,10 @@ def dump( from dictIO.dict_writer import DictWriter DictWriter.write( - source_dict=cast(SDict[TKey, TValue], self), + source_dict=self, target_file=target_file, ) - self._set_source_file(target_file) + self._set_source_file(source_file=target_file) return target_file @@ -355,15 +356,15 @@ def name(self) -> str: return self._name @property - def variables(self) -> dict[str, TValue]: + def variables(self) -> dict[str, V]: """Returns a dict with all Variables currently registered. Returns ------- - Dict[str, TValue] + Dict[str, V] dict of all Variables currently registered. """ - variables: MutableMapping[str, TValue] = {} + variables: MutableMapping[str, V] = {} def extract_variables_from_dict(dict_in: MutableMapping[_K, V]) -> None: for key, value in dict_in.items(): @@ -379,15 +380,16 @@ def extract_variables_from_dict(dict_in: MutableMapping[_K, V]) -> None: continue if isinstance(value, MutableSequence): # special case: item is a list, but does NOT contain a nested dict (-> e.g. a vector or matrix) - variables[key] = value + variables[key] = cast(V, value) else: # base case: item is a single value type - _value = _insert_expression(value, self) + value = cast(V, value) + _value = _insert_expression(value=value, s_dict=self) if not _value_contains_circular_reference(key, _value): variables[key] = _value return - def extract_variables_from_list(list_in: MutableSequence[TValue]) -> None: + def extract_variables_from_list(list_in: MutableSequence[V]) -> None: # sourcery skip: remove-redundant-pass for value in list_in: if isinstance(value, MutableMapping): @@ -401,7 +403,7 @@ def extract_variables_from_list(list_in: MutableSequence[TValue]) -> None: pass return - def list_contains_dict(list_in: MutableSequence[TValue]) -> bool: + def list_contains_dict(list_in: MutableSequence[V]) -> bool: # sourcery skip: merge-duplicate-blocks, use-any for value in list_in: if isinstance(value, MutableMapping): @@ -462,9 +464,9 @@ class attributes in the update. Parameters ---------- - __m : Mapping[TKey, TValue] + m : Mapping[K, V] | Iterable[tuple[K, V]] | None dict containing the keys to be updated and its new values - **kwargs: TValue + **kwargs: V optional keyword arguments. These will be passed on to the update() method of the parent class. """ if m is None: @@ -526,9 +528,9 @@ def _recursive_merge( Parameters ---------- - target_dict : MutableMapping[TKey, TValue] | MutableMapping[int, TValue] + target_dict : MutableMapping[_K, _V] target dict - dict_to_merge : Mapping[TKey, TValue] | Mapping[int, TValue] + dict_to_merge : Mapping[_K, _V] dict to be merged into target dict overwrite : bool, optional if True, existing keys will be overwritten, by default False @@ -536,18 +538,18 @@ def _recursive_merge( for key in dict_to_merge: if ( key in target_dict - and isinstance(target_dict[key], MutableMapping) # pyright: ignore[reportArgumentType] - and isinstance(dict_to_merge[key], Mapping) # pyright: ignore[reportArgumentType] + and isinstance(target_dict[key], MutableMapping) + and isinstance(dict_to_merge[key], Mapping) ): # dict self._recursive_merge( # Recursion - target_dict=cast(MutableMapping[TKey, TValue], target_dict[key]), - dict_to_merge=cast(Mapping[TKey, TValue], dict_to_merge[key]), + target_dict=cast(MutableMapping[K, V], target_dict[key]), + dict_to_merge=cast(Mapping[K, V], dict_to_merge[key]), overwrite=overwrite, ) else: value_in_target_dict_contains_circular_reference = False if isinstance(target_dict, SDict) and key in target_dict: - value = _insert_expression(target_dict[key], target_dict) + value = _insert_expression(value=target_dict[key], s_dict=target_dict) value_in_target_dict_contains_circular_reference = _value_contains_circular_reference(key, value) if overwrite or key not in target_dict or value_in_target_dict_contains_circular_reference: target_dict[key] = dict_to_merge[key] # Update @@ -794,12 +796,11 @@ def __str__(self) -> str: str the string representation """ - from dictIO import ( - NativeFormatter, # __str__ shall be formatted in dictIO native file format - ) + # __str__ shall be formatted in dictIO native file format + from dictIO import NativeFormatter formatter = NativeFormatter() - return formatter.to_string(cast(SDict[TKey, TValue], self)) + return formatter.to_string(self) def __repr__(self) -> str: """Return a string representation of the SDict instance. @@ -839,7 +840,7 @@ def order_keys(self) -> None: self.includes = order_keys(self.includes) return - def find_global_key(self, query: str = "") -> list[TGlobalKey] | None: + def find_global_key(self, query: str = "") -> list[K | int] | None: """Return the global key thread to the first key the value of which matches the passed in query. Function works recursively on nested dicts and is non-greedy: The key of the first match is returned. @@ -854,12 +855,12 @@ def find_global_key(self, query: str = "") -> list[TGlobalKey] | None: Returns ------- - Union[list[TKey], None] + list[K | int] | None global key thread to the first key the value of which matches the passed in query, if found. Otherwise None. """ - return find_global_key(cast(SDict[TKey, TValue], self), query) + return find_global_key(self, query) - def set_global_key(self, global_key: MutableSequence[TKey], value: TValue) -> None: + def set_global_key(self, global_key: MutableSequence[K | int], value: V) -> None: """Set the value for the passed in global key. The global key thread is traversed downwards until arrival at the target key, @@ -867,25 +868,25 @@ def set_global_key(self, global_key: MutableSequence[TKey], value: TValue) -> No Parameters ---------- - global_key : MutableSequence[TValue] + global_key : MutableSequence[K | int] list of keys defining the global key thread to the target key (such as returned by method find_global_key()) - value : TValue + value : V value the target key shall be set to """ set_global_key( - arg=cast(MutableMapping[TKey, TValue], self), + arg=self, global_key=global_key, value=value, ) return - def global_key_exists(self, global_key: MutableSequence[TKey]) -> bool: + def global_key_exists(self, global_key: MutableSequence[K | int]) -> bool: """Check whether the specified global key exists. Parameters ---------- - global_key : MutableSequence[TValue] + global_key : MutableSequence[K | int] global key the existence of which is checked Returns @@ -897,16 +898,16 @@ def global_key_exists(self, global_key: MutableSequence[TKey]) -> bool: probe the existence of (nested) keys in dict """ return global_key_exists( - dict_in=cast(MutableMapping[TKey, TValue], self), + dict_in=self, global_key=global_key, ) - def reduce_scope(self, scope: MutableSequence[TKey]) -> None: + def reduce_scope(self, scope: MutableSequence[K]) -> None: """Reduces the dict to the keys defined in scope. Parameters ---------- - scope : MutableSequence[str] + scope : MutableSequence[K] scope the dict shall be reduced to """ if scope: @@ -953,7 +954,7 @@ def _clean(self) -> None: Doublettes are identified through equality with their lookup values. """ - def _recursive_clean(data: MutableMapping[TKey, TValue]) -> None: + def _recursive_clean(data: MutableMapping[K, V]) -> None: self._clean_data( data=data, ) @@ -964,11 +965,11 @@ def _recursive_clean(data: MutableMapping[TKey, TValue]) -> None: return - _recursive_clean(data=cast(MutableMapping[TKey, TValue], self)) + _recursive_clean(data=self) return - def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None: + def _clean_data(self, data: MutableMapping[K, V]) -> None: """Find and remove doublettes of PLACEHOLDER keys. Find and remove doublettes of following PLACEHOLDER keys within data: @@ -979,7 +980,7 @@ def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None: Doublettes are identified through equality with their lookup values. """ # IDENTIFY all placeholders on current level - keys_on_this_level: list[TKey] = list(data) + keys_on_this_level: list[K] = list(data) block_comments_on_this_level: list[str] = [] includes_on_this_level: list[str] = [] line_comments_on_this_level: list[str] = [] @@ -1001,7 +1002,7 @@ def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None: if block_comment in unique_block_comments_on_this_level: # Found doublette # Remove from current level in data (the dict) - del data[_block_comment] + del data[cast(K, _block_comment)] # ..AND from self.block_comments (the lookup table) del self.block_comments[_id] else: @@ -1015,7 +1016,7 @@ def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None: if include in unique_includes_on_this_level: # Found doublette # Remove from current level in data (the dict) - del data[_include] + del data[cast(K, _include)] # ..AND from self.includes (the lookup table) del self.includes[_id] else: @@ -1029,7 +1030,7 @@ def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None: if line_comment in unique_line_comments_on_this_level: # Found doublette # Remove from current level in data (the dict) - del data[_line_comment] + del data[cast(K, _line_comment)] # ..AND from self.line_comments (the lookup table) del self.line_comments[_id] else: @@ -1071,15 +1072,16 @@ def data(self, data: dict[K, V]) -> None: return -def _insert_expression(value: TValue, s_dict: SDict[K, V]) -> TValue: +def _insert_expression(value: V, s_dict: SDict[K, V]) -> V: if not isinstance(value, str): return value if not re.search(r"EXPRESSION\d{6}", value): - return value + return cast(V, value) if match_index := re.search(r"\d{6}", value): index = int(match_index[0]) - return s_dict.expressions[index]["expression"] if index in s_dict.expressions else value - return value + _value = s_dict.expressions[index]["expression"] if index in s_dict.expressions else value + return cast(V, _value) + return cast(V, value) def _value_contains_circular_reference(key: TKey, value: TValue) -> bool: diff --git a/src/dictIO/dict_reader.py b/src/dictIO/dict_reader.py index 1220aa25..0d0e945c 100644 --- a/src/dictIO/dict_reader.py +++ b/src/dictIO/dict_reader.py @@ -41,7 +41,7 @@ ) from dictIO import Parser, SDict -from dictIO.types import K, TKey, TValue, V +from dictIO.types import K, M, TKey, TValue, V from dictIO.utils.counter import DejaVue __ALL__ = ["DictReader"] @@ -64,7 +64,7 @@ def read( comments: bool = True, scope: MutableSequence[Any] | None = None, parser: Parser | None = None, - ) -> SDict[TKey, TValue]: + ) -> SDict[Any, Any]: """Read a dictionary file in dictIO native file format, as well as JSON and XML. Reads a dict file, parses it and transforms its content into a dictIO dict object (SDict). @@ -207,10 +207,10 @@ def _merge_includes_recursive(parent_dict: SDict[K, V]) -> SDict[K, V]: @staticmethod def _resolve_reference( reference: str, - variables: MutableMapping[str, TValue], - ) -> TValue: + variables: MutableMapping[str, V], + ) -> V: # resolves a single reference - value: TValue = None + value: V = None try: # extract indices, ugly version, nice version is re.sub with a positive lookahead indexing = re.findall(pattern=r"\[.+\]$", string=reference)[0] @@ -286,7 +286,7 @@ def _eval_expressions(dict_in: SDict[K, V]) -> None: eval_result: V | None = None if "$" not in expression: try: - eval_result = eval(expression) # noqa: S307 + eval_result = cast(V, eval(expression)) # noqa: S307 eval_successful = True except NameError: eval_result = cast(V, expression) @@ -294,6 +294,7 @@ def _eval_expressions(dict_in: SDict[K, V]) -> None: except SyntaxError: logger.warning(f'DictReader.(): evaluation of "{expression}" not yet possible') if eval_successful: + assert eval_result is not None while global_key := dict_in.find_global_key(query=placeholder): # Substitute the placeholder in the dict with the result of the evaluated expression dict_in.set_global_key(global_key, value=eval_result) @@ -337,27 +338,27 @@ def _eval_expressions(dict_in: SDict[K, V]) -> None: expression = item["expression"] while global_key := dict_in.find_global_key(query=placeholder): # Substitute the placeholder with the original (or at least partly resolved) expression - dict_in.set_global_key(global_key, value=expression) + dict_in.set_global_key(global_key, value=cast(V, expression)) dict_in.expressions.clear() return @staticmethod - def _remove_comment_keys(data: MutableMapping[TKey, TValue]) -> MutableMapping[TKey, TValue]: + def _remove_comment_keys(data: M) -> M: """Remove comments from data structure for read function call from other programs.""" remove = "[A-Z]+COMMENT[0-9;]+" with contextlib.suppress(Exception): for key in list(data.keys()): # work on a copy of the keys if isinstance(data[key], MutableMapping): - sub_dict = cast(MutableMapping[TKey, TValue], data[key]) + sub_dict = cast(M, data[key]) data.update({key: DictReader._remove_comment_keys(sub_dict)}) # recursion elif re.search(pattern=remove, string=str(key)): _ = data.pop(key) return data @staticmethod - def _remove_include_keys(data: MutableMapping[TKey, TValue]) -> None: + def _remove_include_keys(data: M) -> M: """Remove includes from data structure for read function call from other programs.""" remove = "INCLUDE[0-9;]+" @@ -365,4 +366,4 @@ def _remove_include_keys(data: MutableMapping[TKey, TValue]) -> None: for key in list(data.keys()): # work on a copy of the keys if type(key) is str and re.search(pattern=remove, string=key): _ = data.pop(key) - return + return data diff --git a/src/dictIO/formatter.py b/src/dictIO/formatter.py index 5c17811a..cf8fa026 100644 --- a/src/dictIO/formatter.py +++ b/src/dictIO/formatter.py @@ -461,7 +461,7 @@ def __init__(self) -> None: def to_string( self, - arg: MutableMapping[TKey, TValue], + arg: MutableMapping[K, V], ) -> str: # sourcery skip: dict-comprehension """Create a string representation of the passed in dict in dictIO native file format. @@ -480,13 +480,13 @@ def to_string( # Sort dict in a way that block comment and include statement come first original_data = deepcopy(_arg) - sorted_data: dict[TKey, TValue] = {} + sorted_data: dict[K, V] = {} for key, element in original_data.items(): if type(key) is str and re.search(r"BLOCKCOMMENT\d{6}", key): - sorted_data[key] = element # noqa: PERF403 + sorted_data[cast(K, key)] = element for key, element in original_data.items(): if type(key) is str and re.search(r"INCLUDE\d{6}", key): - sorted_data[key] = element # noqa: PERF403 + sorted_data[cast(K, key)] = element for key in sorted_data: del original_data[key] sorted_data |= original_data @@ -752,7 +752,7 @@ def format_expression_string(self, arg: str) -> str: def insert_block_comments( self, - s_dict: SDict[TKey, TValue], + s_dict: SDict[K, V], s: str, ) -> str: """Insert back all block comments. @@ -810,7 +810,11 @@ def make_default_block_comment(self, block_comment: str = "") -> str: block_comment = default_block_comment + block_comment return block_comment - def insert_includes(self, s_dict: SDict[TKey, TValue], s: str) -> str: + def insert_includes( + self, + s_dict: SDict[K, V], + s: str, + ) -> str: """Insert back all include directives.""" search_pattern: str | Pattern[str] for key, (_, include_file_name, _) in s_dict.includes.items(): @@ -824,7 +828,11 @@ def insert_includes(self, s_dict: SDict[TKey, TValue], s: str) -> str: return s - def insert_line_comments(self, s_dict: SDict[TKey, TValue], s: str) -> str: + def insert_line_comments( + self, + s_dict: SDict[K, V], + s: str, + ) -> str: """Insert back all line directives.""" search_pattern: str | Pattern[str] for key, line_comment in s_dict.line_comments.items(): @@ -864,7 +872,7 @@ def __init__(self) -> None: def to_string( self, - arg: MutableMapping[TKey, TValue] | SDict[TKey, TValue], + arg: MutableMapping[K, V], ) -> str: """Create a string representation of the passed in dict in OpenFOAM dictionary format. @@ -884,7 +892,7 @@ def to_string( # Remove all dict entries starting with underscore def remove_underscore_keys_recursive( - arg: MutableMapping[TKey, TValue], + arg: MutableMapping[K, V], ) -> None: keys = list(arg.keys()) for key in keys: @@ -1008,7 +1016,7 @@ def __init__(self) -> None: def to_string( self, - arg: MutableMapping[TKey, TValue] | SDict[TKey, TValue], + arg: MutableMapping[K, V], ) -> str: """Create a string representation of the passed in dict in JSON dictionary format. @@ -1040,7 +1048,11 @@ def to_string( return s - def insert_includes(self, s_dict: SDict[TKey, TValue], s: str) -> str: + def insert_includes( + self, + s_dict: SDict[K, V], + s: str, + ) -> str: """Insert back all include directives.""" search_pattern: str | Pattern[str] for key, (_, include_file_name, _) in s_dict.includes.items(): @@ -1120,7 +1132,7 @@ def __init__( def to_string( self, - arg: MutableMapping[TKey, TValue] | SDict[TKey, TValue], + arg: MutableMapping[K, V], ) -> str: """Create a string representation of the passed in dict in XML format. @@ -1143,18 +1155,20 @@ def to_string( # Check whether xml opts are contained in dict. # If so, read and use them if "_xmlOpts" in arg: - xml_opts = cast(MutableMapping[TKey, TValue], arg["_xmlOpts"]) + xml_opts = cast(MutableMapping[K, V], arg[cast(K, "_xmlOpts")]) namespaces = ( - cast(MutableMapping[str, str], xml_opts["_nameSpaces"]) if "_nameSpaces" in xml_opts else namespaces + cast(MutableMapping[str, str], xml_opts[cast(K, "_nameSpaces")]) + if "_nameSpaces" in xml_opts + else namespaces ) - root_tag = str(xml_opts["_rootTag"]) if "_rootTag" in xml_opts else root_tag + root_tag = str(xml_opts[cast(K, "_rootTag")]) if "_rootTag" in xml_opts else root_tag root_attributes = ( - cast(MutableMapping[str, str], xml_opts["_rootAttributes"]) + cast(MutableMapping[str, str], xml_opts[cast(K, "_rootAttributes")]) if "_rootAttributes" in xml_opts else root_attributes ) self.remove_node_numbering = ( - bool(xml_opts["_removeNodeNumbering"]) + bool(xml_opts[cast(K, "_removeNodeNumbering")]) if "_removeNodeNumbering" in xml_opts else self.remove_node_numbering ) diff --git a/src/dictIO/parser.py b/src/dictIO/parser.py index fd3fd8dc..5d5e8a34 100644 --- a/src/dictIO/parser.py +++ b/src/dictIO/parser.py @@ -17,7 +17,7 @@ from lxml.etree import _Element as LxmlElement # pyright: ignore[reportPrivateUsage] from dictIO import SDict -from dictIO.types import K, TKey, TSingleValue, TValue, V +from dictIO.types import K, M, S, TKey, TSingleValue, TValue, V from dictIO.utils.counter import BorgCounter if TYPE_CHECKING: @@ -349,8 +349,8 @@ def remove_quotes_from_string( @staticmethod def remove_quotes_from_strings( - arg: MutableMapping[TKey, TValue] | MutableSequence[TValue], - ) -> MutableMapping[TKey, TValue] | MutableSequence[TValue]: + arg: M | S, + ) -> M | S: """Remove quotes from multiple strings. Removes quotes (single and double quotes) from all string objects inside a dict or list. @@ -360,22 +360,24 @@ def remove_quotes_from_strings( Parameters ---------- - arg : Union[MutableMapping[TKey, TValue], MutableSequence[TValue]] + arg : MutableMapping[K, V] | MutableSequence[V] the dict or list containing strings the quotes in which shall be removed Returns ------- - Union[MutableMapping[TKey, TValue], MutableSequence[TValue]] + MutableMapping[K, V] | MutableSequence[V] the original dict or list, yet with quotes in all strings being removed """ if isinstance(arg, MutableMapping): # Dict + arg = cast(M, arg) for key in list(arg.keys()): # work on a copy of keys if isinstance(arg[key], MutableMapping | MutableSequence): # dict or list arg[key] = Parser.remove_quotes_from_strings(arg[key]) # (recursion) elif isinstance(arg[key], str): # str arg[key] = Parser.remove_quotes_from_string(arg[key]) else: # List + arg = cast(S, arg) for index in range(len(arg)): if isinstance(arg[index], MutableMapping | MutableSequence): # dict or list arg[index] = Parser.remove_quotes_from_strings(arg[index]) # (recursion) @@ -1295,7 +1297,7 @@ def _insert_string_literals( # The entry from dict.string_literals is parsed once again, # so that entries representing single value native types # (such as bool ,None, int, float) are transformed to its native type, accordingly. - value = self.parse_value(string_literal) + value = cast(V, self.parse_value(string_literal)) # Replace all occurences of placeholder within the dictionary with the original string literal. # Note: As find_global_key() is non-greedy and returns the key of diff --git a/src/dictIO/utils/dict.py b/src/dictIO/utils/dict.py index e57a015c..762743e6 100644 --- a/src/dictIO/utils/dict.py +++ b/src/dictIO/utils/dict.py @@ -6,7 +6,7 @@ from copy import copy from typing import Any, cast -from dictIO.types import K, M, TKey, TValue, V +from dictIO.types import K, M, TKey, V def order_keys(arg: M) -> M: @@ -139,16 +139,16 @@ def set_global_key( def global_key_exists( - dict_in: MutableMapping[TKey, TValue], - global_key: MutableSequence[TKey], + dict_in: MutableMapping[K, V], + global_key: MutableSequence[K | int], ) -> bool: """Check whether the specified global key exists in the passed in dict. Parameters ---------- - dict_in : MutableMapping[TKey, TValue] + dict_in : MutableMapping[K, V] dict to check for existence of the specified global key - global_key : MutableSequence[TKey] + global_key : MutableSequence[K | int] global key the existence of which is checked in the passed in dict Returns @@ -156,11 +156,11 @@ def global_key_exists( bool True if the specified global key exists, otherwise False """ - _last_branch: MutableMapping[TKey, TValue] = dict_in - _next_branch: MutableMapping[TKey, TValue] | TValue + _last_branch: MutableMapping[K, V] = dict_in + _next_branch: MutableMapping[K, V] | V try: for key in global_key: - _next_branch = _last_branch[key] + _next_branch = _last_branch[key] # type: ignore[index, reportArgumentType] if not isinstance(_next_branch, MutableMapping): return False _last_branch = _next_branch diff --git a/tests/test_cpp_dict.py b/tests/test_cpp_dict.py index 26d448ee..bf88a231 100644 --- a/tests/test_cpp_dict.py +++ b/tests/test_cpp_dict.py @@ -2,7 +2,7 @@ from collections.abc import MutableMapping from copy import deepcopy from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import pytest @@ -17,11 +17,13 @@ order_keys, set_global_key, ) -from dictIO.types import TKey, TValue + +if TYPE_CHECKING: + from dictIO.types import TValue @pytest.fixture -def test_dict() -> SDict[TKey, TValue]: +def test_dict() -> SDict[str, Any]: parser = NativeParser() return parser.parse_file(Path("test_dict_dict")) @@ -326,7 +328,7 @@ def test_order_keys_of_test_dict(test_dict: CppDict) -> None: def test_reduce_scope_of_test_dict(test_dict: CppDict) -> None: # Prepare - scope: list[TKey] = ["scope", "subscope1"] + scope: list[str] = ["scope", "subscope1"] # Execute test_dict.reduce_scope(scope) # Assert diff --git a/tests/test_dict.py b/tests/test_dict.py index 3c4ed416..bce4622b 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -3,7 +3,7 @@ from collections.abc import MutableMapping from copy import deepcopy from pathlib import Path -from typing import Any +from typing import Any, cast import pytest @@ -18,17 +18,17 @@ order_keys, set_global_key, ) -from dictIO.types import TKey, TValue +from dictIO.types import K, TValue, V @pytest.fixture -def test_dict() -> SDict[TKey, TValue]: +def test_dict() -> SDict[str, Any]: parser = NativeParser() return parser.parse_file(Path("test_dict_dict")) def test_init() -> None: - test_dict: SDict[TKey, TValue] = SDict() + test_dict: SDict[str, Any] = SDict() assert test_dict.source_file is None assert test_dict.path == Path.cwd() assert test_dict.name == "" @@ -43,7 +43,7 @@ def test_init() -> None: def test_init_with_file() -> None: - test_dict: SDict[TKey, TValue] = SDict("someDict") + test_dict: SDict[str, Any] = SDict("someDict") assert test_dict.source_file == Path.cwd() / "someDict" assert test_dict.path == Path.cwd() assert test_dict.name == "someDict" @@ -58,11 +58,11 @@ def test_init_with_file() -> None: def test_init_with_base_dict() -> None: - base_dict: dict[TKey, TValue] = { + base_dict: dict[str, Any] = { "key1": "value1", "key2": "value2", } - test_dict: SDict[TKey, TValue] = SDict(base_dict) + test_dict: SDict[str, Any] = SDict(base_dict) assert test_dict == base_dict assert test_dict.source_file is None assert test_dict.path == Path.cwd() @@ -318,26 +318,26 @@ def test_order_keys() -> None: assert key == keys_expected_nested[index] -def test_order_keys_of_test_dict(test_dict: SDict[TKey, TValue]) -> None: +def test_order_keys_of_test_dict(test_dict: SDict[K, V]) -> None: # Prepare # Execute test_dict.order_keys() # Assert - assert str(test_dict["unordered"]) == str(test_dict["ordered"]) + assert str(test_dict[cast(K, "unordered")]) == str(test_dict[cast(K, "ordered")]) -def test_reduce_scope_of_test_dict(test_dict: SDict[TKey, TValue]) -> None: +def test_reduce_scope_of_test_dict(test_dict: SDict[K, V]) -> None: # Prepare - scope: list[TKey] = ["scope", "subscope1"] + scope: list[K] = [cast(K, "scope"), cast(K, "subscope1")] # Execute test_dict.reduce_scope(scope) # Assert dict_out = test_dict assert len(dict_out) == 2 # subscope11, subscope12 # assert dict_out["subscope11"] is not None - assert dict_out["subscope11"]["name"] == "subscope11" - assert dict_out["subscope12"] is not None - assert dict_out["subscope12"]["name"] == "subscope12" + assert dict_out[cast(K, "subscope11")]["name"] == "subscope11" + assert dict_out[cast(K, "subscope12")] is not None + assert dict_out[cast(K, "subscope12")]["name"] == "subscope12" def test_include() -> None: @@ -2023,9 +2023,9 @@ def test_sdict_copy_deepcopy() -> None: assert copied_dict.includes == original_dict.includes -def _construct_test_dict() -> dict[TKey, TValue]: +def _construct_test_dict() -> dict[str, Any]: # construct a test dict with single entries, a nested dict and a nested list - test_dict: dict[TKey, TValue] = { + test_dict: dict[str, Any] = { "A": "string 11", "B": 11, "C": 11.0, @@ -2046,9 +2046,9 @@ def _construct_test_dict() -> dict[TKey, TValue]: return test_dict -def _construct_test_sdict() -> SDict[TKey, TValue]: +def _construct_test_sdict() -> SDict[str, Any]: # construct a test SDict with single entries, a nested dict and a nested list - test_sdict: SDict[TKey, TValue] = SDict(_construct_test_dict()) + test_sdict: SDict[str, Any] = SDict(_construct_test_dict()) test_sdict.expressions |= { 1: { "name": "EXPRESSION000011", @@ -2108,7 +2108,7 @@ def test_load() -> None: def test_dump() -> None: # Prepare target_file: Path = Path("temp_file_test_write_dict") - test_dict: dict[TKey, TValue] = { + test_dict: dict[str, Any] = { "param1": -10.0, "param2": 0.0, "param3": 0.0, diff --git a/tests/test_formatter.py b/tests/test_formatter.py index af52de71..b8ab86ca 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -307,7 +307,7 @@ def test_remove_trailing_spaces(self) -> None: def test_list_with_nested_list(self) -> None: # Prepare - test_obj: dict[TKey, TValue] = { + test_obj: dict[str, Any] = { "blocks": [ "hex", [0, 1, 2, 3, 4, 5, 6, 7], diff --git a/tests/test_parser.py b/tests/test_parser.py index 7af2f282..a47b3beb 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -10,7 +10,7 @@ import pytest from dictIO import NativeParser, Parser, SDict, XmlParser -from dictIO.types import TKey, TValue +from dictIO.types import K, TKey, TValue, V from dictIO.utils.counter import BorgCounter from dictIO.utils.strings import string_diff @@ -1152,7 +1152,7 @@ def test_parse_xml_namespace_default(self) -> None: class SetupHelper: @staticmethod def prepare_dict_until( - dict_to_prepare: SDict[TKey, TValue], + dict_to_prepare: SDict[K, V], until_step: int = -1, file_to_read: str = "test_parser_dict", *,