Skip to content

Commit

Permalink
improve typing (work in progress)
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaasRostock committed Nov 11, 2024
1 parent 143705d commit 2275b7c
Show file tree
Hide file tree
Showing 12 changed files with 216 additions and 178 deletions.
10 changes: 5 additions & 5 deletions src/dictIO/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,11 @@ def merge(self, other: Mapping[K, V]) -> None:
Parameters
----------
other : MutableMapping[TKey, TValue]
other : Mapping[K, V]
dict to be merged
"""
# merge other dict into self (=into self)
self._recursive_merge(self, other)
self._recursive_merge(target_dict=self, dict_to_merge=other)
# merge SDict attributes
self._post_merge(other)
self._clean()
Expand Down Expand Up @@ -859,7 +859,7 @@ def find_global_key(self, query: str = "") -> list[TGlobalKey] | None:
"""
return find_global_key(cast(SDict[TKey, TValue], self), query)

def set_global_key(self, global_key: MutableSequence[TKey], value: TValue = None) -> None:
def set_global_key(self, global_key: MutableSequence[TKey], value: TValue) -> None:
"""Set the value for the passed in global key.
The global key thread is traversed downwards until arrival at the target key,
Expand All @@ -869,8 +869,8 @@ def set_global_key(self, global_key: MutableSequence[TKey], value: TValue = None
----------
global_key : MutableSequence[TValue]
list of keys defining the global key thread to the target key (such as returned by method find_global_key())
value : TValue, optional
value the target key shall be set to, by default None
value : TValue
value the target key shall be set to
"""
set_global_key(
arg=cast(MutableMapping[TKey, TValue], self),
Expand Down
37 changes: 22 additions & 15 deletions src/dictIO/dict_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)

from dictIO import Parser, SDict
from dictIO.types import TKey, TValue
from dictIO.types import K, TKey, TValue, V
from dictIO.utils.counter import DejaVue

__ALL__ = ["DictReader"]
Expand Down Expand Up @@ -114,7 +114,7 @@ def read(
parser = parser or Parser.get_parser(source_file)

# Parse the dict file and transform it into a SDict
parsed_dict = parser.parse_file(source_file, comments=comments)
parsed_dict: SDict[TKey, TValue] = parser.parse_file(source_file, comments=comments)

# Merge dict files included through #include directives, if not actively refrained through opts
if includes:
Expand Down Expand Up @@ -149,7 +149,7 @@ def read(

@staticmethod
def _merge_includes(
parent_dict: SDict[TKey, TValue],
parent_dict: SDict[K, V],
*,
comments: bool = True,
) -> None:
Expand All @@ -159,9 +159,9 @@ def _merge_includes(
djv.reset()

# Inner function: Merge all includes, recursively
def _merge_includes_recursive(parent_dict: SDict[TKey, TValue]) -> SDict[TKey, TValue]:
def _merge_includes_recursive(parent_dict: SDict[K, V]) -> SDict[K, V]:
# empty dict to merge in temporarily, avoiding dict-has-change-error inside the for loop
temp_dict: SDict[TKey, TValue] = SDict()
temp_dict: SDict[K, V] = SDict()

# loop over all possible includes
for _, _, path in parent_dict.includes.values():
Expand All @@ -176,11 +176,18 @@ def _merge_includes_recursive(parent_dict: SDict[TKey, TValue]) -> SDict[TKey, T
logger.warning(f"included dict not found. Merging of {path} aborted.")
else:
parser = Parser.get_parser(source_file=path)
included_dict = parser.parse_file(source_file=path, target_dict=None, comments=comments)
included_dict = cast(
SDict[K, V],
parser.parse_file(
source_file=path,
target_dict=None,
comments=comments,
),
)

# recursion in case the i-th include also has includes
if len(included_dict.includes) != 0:
nested_included_dict = _merge_includes_recursive(included_dict)
nested_included_dict = _merge_includes_recursive(parent_dict=included_dict)
# merge second level
temp_dict.merge(nested_included_dict)

Expand All @@ -193,7 +200,7 @@ def _merge_includes_recursive(parent_dict: SDict[TKey, TValue]) -> SDict[TKey, T
return parent_dict

# Call inner funtion to merge all includes, recursively
parent_dict.merge(_merge_includes_recursive(parent_dict))
parent_dict.merge(_merge_includes_recursive(parent_dict=parent_dict))

return

Expand Down Expand Up @@ -231,7 +238,7 @@ def _resolve_reference(
return value

@staticmethod
def _eval_expressions(dict_in: SDict[TKey, TValue]) -> None:
def _eval_expressions(dict_in: SDict[K, V]) -> None:
# Collect all references contained in expressions
_references: list[str] = []
_refs: list[str]
Expand All @@ -241,15 +248,15 @@ def _eval_expressions(dict_in: SDict[TKey, TValue]) -> None:
_refs = re.findall(pattern=r"\$\w[\w\[\]]*", string=item["expression"])
_references.extend(_refs)
# Resolve references
variables: dict[str, TValue] = dict_in.variables
references: dict[str, TValue] = {
variables: dict[str, V] = dict_in.variables
references: dict[str, V] = {
ref: DictReader._resolve_reference(
reference=ref,
variables=variables,
)
for ref in _references
}
references_resolved: dict[str, TValue] = {
references_resolved: dict[str, V] = {
ref: value
for ref, value in references.items()
if (value is not None) and (not re.search(pattern=r"EXPRESSION|\$", string=str(value)))
Expand All @@ -270,19 +277,19 @@ def _eval_expressions(dict_in: SDict[TKey, TValue]) -> None:
for ref in _refs:
if ref in references_resolved:
expression = re.sub(
pattern=f"{re.escape(ref)}",
pattern=f"{re.escape(pattern=ref)}",
repl=str(references_resolved[ref]),
string=expression,
)

eval_successful: bool = False
eval_result: TValue | None = None
eval_result: V | None = None
if "$" not in expression:
try:
eval_result = eval(expression) # noqa: S307
eval_successful = True
except NameError:
eval_result = expression
eval_result = cast(V, expression)
eval_successful = True
except SyntaxError:
logger.warning(f'DictReader.(): evaluation of "{expression}" not yet possible')
Expand Down
13 changes: 7 additions & 6 deletions src/dictIO/dict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import re
from collections.abc import MutableMapping, MutableSequence
from pathlib import Path
from typing import cast

from dictIO import Formatter, NativeParser, SDict, order_keys
from dictIO.types import TKey, TValue
from dictIO.types import K, TKey, V

__ALL__ = ["DictWriter", "create_target_file_name"]

Expand All @@ -22,7 +23,7 @@ def __init__(self) -> None:

@staticmethod
def write(
source_dict: MutableMapping[TKey, TValue],
source_dict: MutableMapping[K, V],
target_file: str | os.PathLike[str] | None = None,
mode: str = "a",
*,
Expand All @@ -45,15 +46,15 @@ def write(
Parameters
----------
source_dict : Union[MutableMapping[TKey, TValue], SDict]
source_dict : MutableMapping[K, V]
source dict
target_file : Union[str, os.PathLike[str], None], optional
target_file : str | os.PathLike[str] | None, optional
target dict file name, by default None
mode : str, optional
append to target file ('a') or overwrite target file ('w'), by default 'a'
order : bool, optional
if True, the dict will be sorted before writing, by default False
formatter : Union[Formatter, None], optional
formatter : Formatter | None, optional
formatter to be used, by default None
"""
# Check arguments
Expand Down Expand Up @@ -93,7 +94,7 @@ def write(
)
from dictIO import DictReader

existing_dict = DictReader.read(target_file, order=order)
existing_dict = cast(SDict[K, V], DictReader.read(source_file=target_file, order=order))
existing_dict.merge(source_dict)
source_dict = existing_dict

Expand Down
24 changes: 13 additions & 11 deletions src/dictIO/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from numpy import ndarray

from dictIO import SDict
from dictIO.types import M, S, TKey, TSingleValue, TValue
from dictIO.types import K, M, S, TKey, TValue, V
from dictIO.utils.counter import BorgCounter

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,15 +79,15 @@ def get_formatter(cls, target_file: Path | None = None) -> Formatter:
@abstractmethod
def to_string(
self,
arg: MutableMapping[TKey, TValue] | SDict[TKey, TValue],
arg: MutableMapping[K, V],
) -> str:
"""Create a string representation of the passed in dict.
Note: Override this method when implementing a specific Formatter.
Parameters
----------
arg : Union[MutableMapping[TKey, TValue], SDict]
arg : Union[MutableMapping[K, V]]
dict to be formatted
Returns
Expand All @@ -99,20 +99,20 @@ def to_string(

def format_value(
self,
arg: TSingleValue | TValue,
) -> str | TValue:
arg: V,
) -> str | V:
"""Format a single value.
Formats a single value of type TSingleValue = str | int | float | bool | None
Parameters
----------
arg : TSingleValue | TValue
arg : V
the value to be formatted
Returns
-------
str | TValue
str | V
the formatted string representation of the passed in value,
if value is of a single value type. Otherwise the value itself.
"""
Expand All @@ -130,7 +130,7 @@ def format_value(
return self.format_float(arg)

# If arg is not of a single value type, return it as is.
return arg
return cast(V, arg)

@overload
def format_values(
Expand Down Expand Up @@ -160,12 +160,12 @@ def format_values(
Parameters
----------
arg : Union[MutableMapping[TKey, TValue], MutableSequence[TValue]]
arg : MutableMapping[K, V] | MutableSequence[V]
the dict or list containing the values to be formatted.
Returns
-------
MutableMapping[TKey, str] | MutableSequence[str]
MutableMapping[K, V] | MutableSequence[V]
a copy of the passed in dict or list, with all values formatted.
"""
item: TValue
Expand Down Expand Up @@ -213,7 +213,9 @@ def format_key(
the formatted string representation of the passed in key
"""
skey: str
skey = self.format_value(arg) if isinstance(arg, TSingleValue) else str(arg)
key = self.format_value(arg)

skey = key if isinstance(key, str) else str(key)
return skey

def format_bool(self, arg: bool) -> str: # noqa: FBT001
Expand Down
Loading

0 comments on commit 2275b7c

Please sign in to comment.