Skip to content

Commit

Permalink
improve typing (work in progress)
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaasRostock committed Nov 11, 2024
1 parent 2275b7c commit f2df98d
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 119 deletions.
104 changes: 53 additions & 51 deletions src/dictIO/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
from copy import copy
from pathlib import Path
from typing import (
Any,
TypeVar,
cast,
overload,
)

from dictIO.types import K, TGlobalKey, TKey, TValue, V
from dictIO.types import K, TKey, TValue, V
from dictIO.utils.counter import BorgCounter
from dictIO.utils.dict import (
find_global_key,
Expand Down Expand Up @@ -160,7 +161,7 @@ def fromkeys(
cls,
iterable: Iterable[_K],
value: None = None,
) -> SDict[_K, TValue | None]:
) -> SDict[_K, Any | None]:
pass

@overload
Expand All @@ -177,7 +178,7 @@ def fromkeys(
cls,
iterable: Iterable[_K],
value: _V | None = None,
) -> SDict[_K, _V] | SDict[_K, TValue | None]:
) -> SDict[_K, _V] | SDict[_K, Any | None]:
"""Create a new SDict instance from the keys of an iterable.

Parameters
Expand All @@ -189,7 +190,7 @@ def fromkeys(

Returns
-------
SDict[_K, _V] | SDict[_K, TValue | None]
SDict[_K, _V] | SDict[_K, Any | None]
The created SDict instance.
"""
new_dict: SDict[_K, _V] = cast(SDict[_K, _V], cls())
Expand Down Expand Up @@ -245,7 +246,7 @@ def load(

from dictIO.dict_reader import DictReader

loaded_dict: SDict[TKey, TValue] = DictReader.read(
loaded_dict: SDict[K, V] = DictReader.read(
source_file=source_file,
)
self.reset()
Expand All @@ -254,7 +255,7 @@ def load(
# Maybe this method needs to be refactored to a factory function, returning a new `SDict` instance
# with the actual types of `loaded_dict`.
# CLAROS, 2024-11-06
self.update(cast(SDict[K, V], loaded_dict))
self.update(loaded_dict)
self._set_source_file(source_file)
return self

Expand Down Expand Up @@ -299,10 +300,10 @@ def dump(
from dictIO.dict_writer import DictWriter

DictWriter.write(
source_dict=cast(SDict[TKey, TValue], self),
source_dict=self,
target_file=target_file,
)
self._set_source_file(target_file)
self._set_source_file(source_file=target_file)

return target_file

Expand Down Expand Up @@ -355,15 +356,15 @@ def name(self) -> str:
return self._name

@property
def variables(self) -> dict[str, TValue]:
def variables(self) -> dict[str, V]:
"""Returns a dict with all Variables currently registered.

Returns
-------
Dict[str, TValue]
Dict[str, V]
dict of all Variables currently registered.
"""
variables: MutableMapping[str, TValue] = {}
variables: MutableMapping[str, V] = {}

def extract_variables_from_dict(dict_in: MutableMapping[_K, V]) -> None:
for key, value in dict_in.items():
Expand All @@ -379,15 +380,16 @@ def extract_variables_from_dict(dict_in: MutableMapping[_K, V]) -> None:
continue
if isinstance(value, MutableSequence):
# special case: item is a list, but does NOT contain a nested dict (-> e.g. a vector or matrix)
variables[key] = value
variables[key] = cast(V, value)
else:
# base case: item is a single value type
_value = _insert_expression(value, self)
value = cast(V, value)
_value = _insert_expression(value=value, s_dict=self)
if not _value_contains_circular_reference(key, _value):
variables[key] = _value
return

def extract_variables_from_list(list_in: MutableSequence[TValue]) -> None:
def extract_variables_from_list(list_in: MutableSequence[V]) -> None:
# sourcery skip: remove-redundant-pass
for value in list_in:
if isinstance(value, MutableMapping):
Expand All @@ -401,7 +403,7 @@ def extract_variables_from_list(list_in: MutableSequence[TValue]) -> None:
pass
return

def list_contains_dict(list_in: MutableSequence[TValue]) -> bool:
def list_contains_dict(list_in: MutableSequence[V]) -> bool:
# sourcery skip: merge-duplicate-blocks, use-any
for value in list_in:
if isinstance(value, MutableMapping):
Expand Down Expand Up @@ -462,9 +464,9 @@ class attributes in the update.

Parameters
----------
__m : Mapping[TKey, TValue]
m : Mapping[K, V] | Iterable[tuple[K, V]] | None
dict containing the keys to be updated and its new values
**kwargs: TValue
**kwargs: V
optional keyword arguments. These will be passed on to the update() method of the parent class.
"""
if m is None:
Expand Down Expand Up @@ -526,28 +528,28 @@ def _recursive_merge(

Parameters
----------
target_dict : MutableMapping[TKey, TValue] | MutableMapping[int, TValue]
target_dict : MutableMapping[_K, _V]
target dict
dict_to_merge : Mapping[TKey, TValue] | Mapping[int, TValue]
dict_to_merge : Mapping[_K, _V]
dict to be merged into target dict
overwrite : bool, optional
if True, existing keys will be overwritten, by default False
"""
for key in dict_to_merge:
if (
key in target_dict
and isinstance(target_dict[key], MutableMapping) # pyright: ignore[reportArgumentType]
and isinstance(dict_to_merge[key], Mapping) # pyright: ignore[reportArgumentType]
and isinstance(target_dict[key], MutableMapping)
and isinstance(dict_to_merge[key], Mapping)
): # dict
self._recursive_merge( # Recursion
target_dict=cast(MutableMapping[TKey, TValue], target_dict[key]),
dict_to_merge=cast(Mapping[TKey, TValue], dict_to_merge[key]),
target_dict=cast(MutableMapping[K, V], target_dict[key]),
dict_to_merge=cast(Mapping[K, V], dict_to_merge[key]),
overwrite=overwrite,
)
else:
value_in_target_dict_contains_circular_reference = False
if isinstance(target_dict, SDict) and key in target_dict:
value = _insert_expression(target_dict[key], target_dict)
value = _insert_expression(value=target_dict[key], s_dict=target_dict)
value_in_target_dict_contains_circular_reference = _value_contains_circular_reference(key, value)
if overwrite or key not in target_dict or value_in_target_dict_contains_circular_reference:
target_dict[key] = dict_to_merge[key] # Update
Expand Down Expand Up @@ -794,12 +796,11 @@ def __str__(self) -> str:
str
the string representation
"""
from dictIO import (
NativeFormatter, # __str__ shall be formatted in dictIO native file format
)
# __str__ shall be formatted in dictIO native file format
from dictIO import NativeFormatter

formatter = NativeFormatter()
return formatter.to_string(cast(SDict[TKey, TValue], self))
return formatter.to_string(self)

def __repr__(self) -> str:
"""Return a string representation of the SDict instance.
Expand Down Expand Up @@ -839,7 +840,7 @@ def order_keys(self) -> None:
self.includes = order_keys(self.includes)
return

def find_global_key(self, query: str = "") -> list[TGlobalKey] | None:
def find_global_key(self, query: str = "") -> list[K | int] | None:
"""Return the global key thread to the first key the value of which matches the passed in query.

Function works recursively on nested dicts and is non-greedy: The key of the first match is returned.
Expand All @@ -854,38 +855,38 @@ def find_global_key(self, query: str = "") -> list[TGlobalKey] | None:

Returns
-------
Union[list[TKey], None]
list[K | int] | None
global key thread to the first key the value of which matches the passed in query, if found. Otherwise None.
"""
return find_global_key(cast(SDict[TKey, TValue], self), query)
return find_global_key(self, query)

def set_global_key(self, global_key: MutableSequence[TKey], value: TValue) -> None:
def set_global_key(self, global_key: MutableSequence[K | int], value: V) -> None:
"""Set the value for the passed in global key.

The global key thread is traversed downwards until arrival at the target key,
the value of which is then set.

Parameters
----------
global_key : MutableSequence[TValue]
global_key : MutableSequence[K | int]
list of keys defining the global key thread to the target key (such as returned by method find_global_key())
value : TValue
value : V
value the target key shall be set to
"""
set_global_key(
arg=cast(MutableMapping[TKey, TValue], self),
arg=self,
global_key=global_key,
value=value,
)

return

def global_key_exists(self, global_key: MutableSequence[TKey]) -> bool:
def global_key_exists(self, global_key: MutableSequence[K | int]) -> bool:
"""Check whether the specified global key exists.

Parameters
----------
global_key : MutableSequence[TValue]
global_key : MutableSequence[K | int]
global key the existence of which is checked

Returns
Expand All @@ -897,16 +898,16 @@ def global_key_exists(self, global_key: MutableSequence[TKey]) -> bool:
probe the existence of (nested) keys in dict
"""
return global_key_exists(
dict_in=cast(MutableMapping[TKey, TValue], self),
dict_in=self,
global_key=global_key,
)

def reduce_scope(self, scope: MutableSequence[TKey]) -> None:
def reduce_scope(self, scope: MutableSequence[K]) -> None:
"""Reduces the dict to the keys defined in scope.

Parameters
----------
scope : MutableSequence[str]
scope : MutableSequence[K]
scope the dict shall be reduced to
"""
if scope:
Expand Down Expand Up @@ -953,7 +954,7 @@ def _clean(self) -> None:
Doublettes are identified through equality with their lookup values.
"""

def _recursive_clean(data: MutableMapping[TKey, TValue]) -> None:
def _recursive_clean(data: MutableMapping[K, V]) -> None:
self._clean_data(
data=data,
)
Expand All @@ -964,11 +965,11 @@ def _recursive_clean(data: MutableMapping[TKey, TValue]) -> None:

return

_recursive_clean(data=cast(MutableMapping[TKey, TValue], self))
_recursive_clean(data=self)

return

def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None:
def _clean_data(self, data: MutableMapping[K, V]) -> None:
"""Find and remove doublettes of PLACEHOLDER keys.

Find and remove doublettes of following PLACEHOLDER keys within data:
Expand All @@ -979,7 +980,7 @@ def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None:
Doublettes are identified through equality with their lookup values.
"""
# IDENTIFY all placeholders on current level
keys_on_this_level: list[TKey] = list(data)
keys_on_this_level: list[K] = list(data)
block_comments_on_this_level: list[str] = []
includes_on_this_level: list[str] = []
line_comments_on_this_level: list[str] = []
Expand All @@ -1001,7 +1002,7 @@ def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None:
if block_comment in unique_block_comments_on_this_level:
# Found doublette
# Remove from current level in data (the dict)
del data[_block_comment]
del data[cast(K, _block_comment)]
# ..AND from self.block_comments (the lookup table)
del self.block_comments[_id]
else:
Expand All @@ -1015,7 +1016,7 @@ def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None:
if include in unique_includes_on_this_level:
# Found doublette
# Remove from current level in data (the dict)
del data[_include]
del data[cast(K, _include)]
# ..AND from self.includes (the lookup table)
del self.includes[_id]
else:
Expand All @@ -1029,7 +1030,7 @@ def _clean_data(self, data: MutableMapping[TKey, TValue]) -> None:
if line_comment in unique_line_comments_on_this_level:
# Found doublette
# Remove from current level in data (the dict)
del data[_line_comment]
del data[cast(K, _line_comment)]
# ..AND from self.line_comments (the lookup table)
del self.line_comments[_id]
else:
Expand Down Expand Up @@ -1071,15 +1072,16 @@ def data(self, data: dict[K, V]) -> None:
return


def _insert_expression(value: TValue, s_dict: SDict[K, V]) -> TValue:
def _insert_expression(value: V, s_dict: SDict[K, V]) -> V:
if not isinstance(value, str):
return value
if not re.search(r"EXPRESSION\d{6}", value):
return value
return cast(V, value)
if match_index := re.search(r"\d{6}", value):
index = int(match_index[0])
return s_dict.expressions[index]["expression"] if index in s_dict.expressions else value
return value
_value = s_dict.expressions[index]["expression"] if index in s_dict.expressions else value
return cast(V, _value)
return cast(V, value)


def _value_contains_circular_reference(key: TKey, value: TValue) -> bool:
Expand Down
Loading

0 comments on commit f2df98d

Please sign in to comment.