diff --git a/src/pyserials/nested_dict.py b/src/pyserials/nested_dict.py index 42e72c3..ba94e23 100644 --- a/src/pyserials/nested_dict.py +++ b/src/pyserials/nested_dict.py @@ -24,15 +24,17 @@ def __init__( end_list: str = "]]$", end_unpack: str = "}}*", end_code: str = "}}#", - recursive: bool = True, raise_no_match: bool = True, leave_no_match: bool = False, no_match_value: Any = None, code_context: dict[str, Any] | None = None, + code_context_partial: dict[str, Callable | tuple[Callable, str]] | None = None, + code_context_call: dict[str, Callable[[Callable], Any]] | None = None, stringer: Callable[[str], str] = str, unpack_string_joiner: str = ", ", relative_template_keys: list[str] | None = None, implicit_root: bool = True, + getter_function_name: str = "get", ): self._data = data or {} self._templater = _ps.update.TemplateFiller( @@ -47,29 +49,21 @@ def __init__( end_list=end_list, end_unpack=end_unpack, end_code=end_code, + raise_no_match=raise_no_match, + leave_no_match=leave_no_match, + no_match_value=no_match_value, + code_context=code_context, + code_context_partial=code_context_partial, + code_context_call=code_context_call, + stringer=stringer, + unpack_string_joiner=unpack_string_joiner, + relative_template_keys=relative_template_keys, + implicit_root=implicit_root, + getter_function_name=getter_function_name, ) - self._recursive = recursive - self._raise_no_match = raise_no_match - self._leave_no_match = leave_no_match - self._no_match_value = no_match_value - self._code_context = code_context or {} - self._stringer = stringer - self._unpack_string_joiner = unpack_string_joiner - self._relative_template_keys = relative_template_keys or [] - self._implicit_root = implicit_root return - def fill( - self, - path: str = "", - recursive: bool | None = None, - raise_no_match: bool | None = None, - leave_no_match: bool | None = None, - code_context: dict[str, Any] | None = None, - stringer: Callable[[str], str] | None = None, - unpack_string_joiner: str | 
def fill_data(self, data, current_path: str = ""):
    """Fill the templates found in `data` using this object's own data as the source.

    Parameters
    ----------
    data
        The templated value to resolve; forwarded to the templater as `template`.
    current_path
        Path associated with `data`; forwarded unchanged to the templater.

    Returns
    -------
    The value returned by the templater's `fill` method.
    """
    templater = self._templater
    # The full stored data is always the lookup source; `data` is only the template.
    filled = templater.fill(
        data=self._data,
        template=data,
        current_path=current_path,
    )
    return filled
as _partial import jsonpath_ng as _jsonpath from jsonpath_ng import exceptions as _jsonpath_exceptions @@ -116,6 +117,17 @@ def __init__( end_list: str = "]]$", end_unpack: str = "}}*", end_code: str = "}}#", + raise_no_match: bool = True, + leave_no_match: bool = False, + no_match_value: Any = None, + code_context: dict[str, Any] | None = None, + code_context_partial: dict[str, Callable | tuple[Callable, str]] | None = None, + code_context_call: dict[str, Callable[[Callable], Any]] | None = None, + stringer: Callable[[str], str] = str, + unpack_string_joiner: str = ", ", + relative_template_keys: list[str] | None = None, + implicit_root: bool = True, + getter_function_name: str = "get", ): self._marker_start_value = marker_start_value self._marker_end_value = marker_end_value @@ -125,99 +137,64 @@ def __init__( self._pattern_list = _RegexPattern(start=start_list, end=end_list) self._pattern_unpack = _RegexPattern(start=start_unpack, end=end_unpack) self._pattern_code = _RegexPattern(start=start_code, end=end_code) - self._add_prefix = True - - self._pattern_value: dict[int, _RegexPattern] = {} - self._data = None - self._source = None - self._recursive = None - self._path = None - self._raise_no_match = None - self._template_keys = None - self._ignore_templates = True - self._leave_no_match = False - self._no_match_value = None - self._code_context = {} - self._stringer = str - self._unpack_string_joiner = ", " - self._path_history = [] - return - - def _get_value_regex_pattern(self, level: int = 0) -> _RegexPattern: - level_patterns = self._pattern_value.setdefault(level, {}) - if level in level_patterns: - return level_patterns[level] - count = self._repeater_count_value + level - pattern = _RegexPattern( - start=f"{self._marker_start_value}{self._repeater_start_value * count} ", - end=f" {self._repeater_end_value * count}{self._marker_end_value}", - ) - level_patterns[level] = pattern - return pattern - def fill( - self, - templated_data: dict | list | str, - 
source_data: dict | list, - current_path: str = "", - recursive: bool = True, - raise_no_match: bool = True, - leave_no_match: bool = False, - no_match_value: Any = None, - code_context: dict[str, Any] | None = None, - stringer: Callable[[str], str] = str, - unpack_string_joiner: str = ", ", - relative_template_keys: list[str] | None = None, - implicit_root: bool = True, - level: int = 0, - ): - self._data = templated_data - self._source = source_data - self._recursive = recursive self._raise_no_match = raise_no_match self._leave_no_match = leave_no_match self._no_match_value = no_match_value self._code_context = code_context or {} + self._code_context_partial = code_context_partial or {} + self._code_context_call = code_context_call or {} self._stringer = stringer self._unpack_string_joiner = unpack_string_joiner self._add_prefix = implicit_root self._template_keys = relative_template_keys or [] - self._path_history = [] + self._getter_function_name = getter_function_name + + self._pattern_value: dict[int, _RegexPattern] = {} + self._data = None + return + + def fill( + self, + data: dict | list, + template: dict | list | str | None = None, + current_path: str = "", + ): + self._data = data path = (f"$.{current_path}" if self._add_prefix else current_path) if current_path else "$" - if not relative_template_keys: - self._ignore_templates = False - return self._recursive_subst( - templ=self._data, - current_path=path, - relative_path_anchor=path, - level=level, - ) - self._ignore_templates = True - first_pass = self._recursive_subst( - templ=self._data, - current_path=path, - relative_path_anchor=path, - level=level, - ) - if self._data is self._source: - self._source = first_pass - self._data = first_pass - self._ignore_templates = False - self._path_history = [] return self._recursive_subst( - templ=self._data, + templ=template or data, current_path=path, relative_path_anchor=path, - level=level, + level=0, + current_chain=[path], ) - def _recursive_subst(self, 
templ, current_path: str, relative_path_anchor: str, level: int, internal=False): + def _recursive_subst(self, templ, current_path: str, relative_path_anchor: str, level: int, current_chain: list[str]): + + def get_code_value(match: _re.Match | str): + + def getter_function(path: str, default: Any = None, search: bool = False): + value, matched = get_address_value(path, return_all_matches=search, from_code=True) + if matched: + return value + if search: + return [] + return default - def get_code_value(code_str: str): + code_str = match if isinstance(match, str) else match.group(1) code_lines = ["def __inline_code__():"] code_lines.extend([f" {line}" for line in code_str.strip("\n").splitlines()]) code_str_full = "\n".join(code_lines) - global_context = self._code_context.copy() + global_context = self._code_context.copy() | {self._getter_function_name: getter_function} + for name, partial_func_data in self._code_context_partial.items(): + if isinstance(partial_func_data, tuple): + func, arg_name = partial_func_data + global_context[name] = _partial(func, **{arg_name: getter_function}) + else: + global_context[name] = _partial(partial_func_data, getter_function) + for name, call_func in self._code_context_call.items(): + global_context[name] = call_func(getter_function) local_context = {} try: exec(code_str_full, global_context, local_context) @@ -228,8 +205,9 @@ def get_code_value(code_str: str): path_invalid=current_path, ) - def get_address_value(re_match, return_all_matches: bool = False): - path, num_periods = self._remove_leading_periods(re_match.group(1).strip()) + def get_address_value(match: _re.Match | str, return_all_matches: bool = False, from_code: bool = False): + raw_path = match if isinstance(match, str) else str(match.group(1)) + path, num_periods = self._remove_leading_periods(raw_path.strip()) if num_periods == 0: path = f"$.{path}" if self._add_prefix else path try: @@ -239,13 +217,14 @@ def get_address_value(re_match, return_all_matches: bool = 
False): path_invalid=path, description_template="JSONPath expression {path_invalid} is invalid.", ) - if self._ignore_templates: - path_fields = self._extract_fields(path_expr) - has_template_key = any(field in self._template_keys for field in path_fields) - if has_template_key: - return re_match.string if num_periods: - root_path_expr = _jsonpath.parse(relative_path_anchor) + if relative_path_anchor != current_path: + path_fields = self._extract_fields(_jsonpath.parse(current_path)) + has_template_key = any(field in self._template_keys for field in path_fields) + anchor_path = relative_path_anchor if has_template_key else current_path + else: + anchor_path = current_path + root_path_expr = _jsonpath.parse(anchor_path) for period in range(num_periods): if isinstance(root_path_expr, _jsonpath.Root): raise_error( @@ -256,17 +235,21 @@ def get_address_value(re_match, return_all_matches: bool = False): ), ) root_path_expr = root_path_expr.left - path_expr = _jsonpath.Child(root_path_expr, path_expr) - value, matched = get_value(path_expr, return_all_matches) + path_expr = self._concat_json_paths(root_path_expr, path_expr) + value, matched = get_value(path_expr, return_all_matches, from_code) + if from_code: + return value, matched if matched: return value if self._leave_no_match: - return re_match.group() + return match.group() return self._no_match_value - def get_value(jsonpath, return_all_matches: bool) -> tuple[Any, bool]: + def get_value(jsonpath, return_all_matches: bool, from_code: bool) -> tuple[Any, bool]: matches = _rec_match(jsonpath) if not matches: + if from_code: + return None, False if return_all_matches: return [], True if self._raise_no_match: @@ -277,8 +260,6 @@ def get_value(jsonpath, return_all_matches: bool) -> tuple[Any, bool]: return None, False values = [m.value for m in matches] output = values if return_all_matches or len(values) > 1 else values[0] - if not self._recursive: - return output, True if relative_path_anchor == current_path: 
path_fields = self._extract_fields(jsonpath) has_template_key = any(field in self._template_keys for field in path_fields) @@ -290,10 +271,11 @@ def get_value(jsonpath, return_all_matches: bool) -> tuple[Any, bool]: current_path=str(jsonpath), relative_path_anchor=_rel_path_anchor, level=0, + current_chain=current_chain + [str(jsonpath)], ), True def _rec_match(expr) -> list: - matches = expr.find(self._source) + matches = expr.find(self._data) if matches: return matches if isinstance(expr.left, _jsonpath.Root): @@ -306,6 +288,7 @@ def _rec_match(expr) -> list: current_path=str(expr.left), relative_path_anchor=str(expr.left), level=0, + current_chain=current_chain + [str(expr.left)], ) if isinstance(left_match.value, str) else left_match.value right_matches = expr.right.find(left_match_filled) whole_matches.extend(right_matches) @@ -314,33 +297,44 @@ def _rec_match(expr) -> list: def get_relative_path(new_path): return new_path if current_path == relative_path_anchor else relative_path_anchor - def raise_error( - path_invalid: str, - description_template: str, - ): - raise _exception.update.PySerialsUpdateTemplatedDataError( - description_template=description_template, - path_invalid=path_invalid, - path=current_path, - data=templ, - data_full=self._data, - data_source=self._source, - template_start=self._marker_start_value, - template_end=self._marker_end_value, + def fill_nested_values(match: _re.Match | str): + pattern_nested = self._get_value_regex_pattern(level=level + 1) + return pattern_nested.sub( + lambda x: self._recursive_subst( + templ=x.group(), + current_path=current_path, + relative_path_anchor=get_relative_path(current_path), + level=level + 1, + current_chain=current_chain, + ), + match if isinstance(match, str) else match.group(1), ) def string_filler_unpack(match: _re.Match): - match_list = self._pattern_list.fullmatch(match.group(1).strip()) + path = str(match.group(1)).strip() + match_list = self._pattern_list.fullmatch(path) if match_list: 
values = get_address_value(match_list, return_all_matches=True) else: - match_code = self._pattern_code.fullmatch(match.group(1).strip()) + match_code = self._pattern_code.fullmatch(path) if match_code: - values = get_code_value(match_code.group(1)) + values = get_code_value(match_code) else: - values = get_address_value(match) + values = get_address_value(path) return self._unpack_string_joiner.join([self._stringer(val) for val in values]) + def raise_error(path_invalid: str, description_template: str): + raise _exception.update.PySerialsUpdateTemplatedDataError( + description_template=description_template, + path_invalid=path_invalid, + path=current_path, + data=templ, + data_full=self._data, + data_source=self._data, + template_start=self._marker_start_value, + template_end=self._marker_end_value, + ) + # if not internal: # self._path_history.append(current_path) # loop = self._find_loop() @@ -358,49 +352,47 @@ def string_filler_unpack(match: _re.Match): # ) if isinstance(templ, str): - pattern_nested = self._get_value_regex_pattern(level=level + 1) - templ_nested_filled = pattern_nested.sub( - lambda x: self._recursive_subst( - templ=x.group(), - current_path=current_path, - relative_path_anchor=get_relative_path(current_path), - level=level+1, - internal=True, - ), - templ - ) + # Handle value blocks pattern_value = self._get_value_regex_pattern(level=level) - whole_match_value = pattern_value.fullmatch(templ_nested_filled) - if whole_match_value: - return get_address_value(whole_match_value) - templ_values_filled = pattern_value.sub( - lambda x: str(get_address_value(x)), - templ_nested_filled - ) - whole_match_list = self._pattern_list.fullmatch(templ_values_filled.strip()) - if whole_match_list: - return get_address_value(whole_match_list, return_all_matches=True) - whole_match_unpack = self._pattern_unpack.fullmatch(templ_values_filled.strip()) - if whole_match_unpack: - submatch_list = self._pattern_list.fullmatch(whole_match_unpack.group(1).strip()) - if 
submatch_list: - return get_address_value(submatch_list, return_all_matches=True) - submatch_code = self._pattern_code.fullmatch(whole_match_unpack.group(1).strip()) + match_value = pattern_value.fullmatch(templ) + if match_value: + return get_address_value(fill_nested_values(match_value)) + # Handle list blocks + match_list = self._pattern_list.fullmatch(templ) + if match_list: + return get_address_value(fill_nested_values(match_list), return_all_matches=True) + # Handle code blocks + match_code = self._pattern_code.fullmatch(templ) + if match_code: + return get_code_value(match_code) + # Handle unpack blocks + match_unpack = self._pattern_unpack.fullmatch(templ) + if match_unpack: + unpack_value = match_unpack.group(1) + submatch_code = self._pattern_code.fullmatch(unpack_value) if submatch_code: - return get_code_value(submatch_code.group(1)) - return get_address_value(whole_match_unpack) - whole_match_code = self._pattern_code.fullmatch(templ_values_filled.strip()) - if whole_match_code: - templ_list_filled = self._pattern_list.sub( - lambda x: str(get_address_value(x, return_all_matches=True)), - whole_match_code.group(1) - ) - return get_code_value(templ_list_filled) - unpacked_filled = self._pattern_unpack.sub(string_filler_unpack, templ_values_filled) - return self._pattern_code.sub( - lambda x: self._stringer(get_code_value(x.group(1))), + return get_code_value(submatch_code) + unpack_value = fill_nested_values(unpack_value) + submatch_list = self._pattern_list.fullmatch(unpack_value) + if submatch_list: + return get_address_value(unpack_value, return_all_matches=True) + return get_address_value(unpack_value) + # Handle strings + code_blocks_filled = self._pattern_code.sub( + lambda x: self._stringer(get_code_value(x)), + templ + ) + nested_values_filled = fill_nested_values(code_blocks_filled) + unpacked_filled = self._pattern_unpack.sub(string_filler_unpack, nested_values_filled) + lists_filled = self._pattern_list.sub( + lambda x: 
self._stringer(get_address_value(x)), unpacked_filled ) + templ_values_filled = pattern_value.sub( + lambda x: self._stringer(get_address_value(x)), + lists_filled + ) + return templ_values_filled if isinstance(templ, list): out = [] @@ -411,8 +403,9 @@ def string_filler_unpack(match: _re.Match): current_path=new_path, relative_path_anchor=get_relative_path(new_path), level=0, + current_chain=current_chain + [new_path], ) - if isinstance(elem, str) and self._pattern_unpack.fullmatch(elem.strip()): + if isinstance(elem, str) and self._pattern_unpack.fullmatch(elem): out.extend(elem_filled) else: out.append(elem_filled) @@ -426,9 +419,9 @@ def string_filler_unpack(match: _re.Match): current_path=current_path, relative_path_anchor=relative_path_anchor, level=0, - internal=True, + current_chain=current_chain, ) - if isinstance(key, str) and self._pattern_unpack.fullmatch(key.strip()): + if isinstance(key, str) and self._pattern_unpack.fullmatch(key): new_dict.update(key_filled) continue if key_filled in self._template_keys: @@ -440,20 +433,32 @@ def string_filler_unpack(match: _re.Match): current_path=new_path, relative_path_anchor=get_relative_path(new_path), level=0, + current_chain=current_chain + [new_path], ) return new_dict return templ - def _find_loop(self): - for pattern_length in range(1, len(self._path_history) // 2 + 1): - # Slice the end of the list into two consecutive patterns - pattern = self._path_history[-pattern_length:] - previous_pattern = self._path_history[-2 * pattern_length:-pattern_length] - # Check if the two patterns are the same - if pattern == previous_pattern: - pattern.insert(0, pattern[-1]) - return pattern - return + # def _find_loop(self): + # for pattern_length in range(1, len(self._path_history) // 2 + 1): + # # Slice the end of the list into two consecutive patterns + # pattern = self._path_history[-pattern_length:] + # previous_pattern = self._path_history[-2 * pattern_length:-pattern_length] + # # Check if the two patterns are 
def _get_value_regex_pattern(self, level: int = 0) -> _RegexPattern:
    """Return the value-template pattern for a nesting level, creating it on first use.

    The start and end repeater strings are repeated
    `self._repeater_count_value + level` times between the start and end
    markers. Each pattern is cached in `self._pattern_value` under its level.
    """
    try:
        return self._pattern_value[level]
    except KeyError:
        pass
    repeat = self._repeater_count_value + level
    opening = f"{self._marker_start_value}{self._repeater_start_value * repeat} "
    closing = f" {self._repeater_end_value * repeat}{self._marker_end_value}"
    built = _RegexPattern(start=opening, end=closing)
    self._pattern_value[level] = built
    return built

def _concat_json_paths(self, path1, path2):
    """Append JSONPath expression `path2` to `path1`.

    `path2` is unwound along its chain of left-nested `Child` nodes so that
    `path1` is attached at the innermost left position and every right-hand
    part of `path2` is re-applied on top of it in the original order.
    """
    right_parts = []
    node = path2
    # Walk down the left spine of `path2`, remembering each right-hand part.
    while isinstance(node, _jsonpath.Child):
        right_parts.append(node.right)
        node = node.left
    joined = _jsonpath.Child(path1, node)
    # Rebuild from the innermost part outwards.
    for right in reversed(right_parts):
        joined = _jsonpath.Child(joined, right)
    return joined