diff --git a/pyproject.toml b/pyproject.toml index b10774a25..e3a5a81f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "h5py>=3.6.0", "xarray>=0.20.2", "PyYAML>=6.0", - "numpy>=1.21.2", + "numpy>=1.21.2,<2.0.0", "pandas>=1.3.2", "ase>=3.19.0", "mergedeep", diff --git a/src/pynxtools/dataconverter/nexus_tree.py b/src/pynxtools/dataconverter/nexus_tree.py index de7f11da4..588227a1e 100644 --- a/src/pynxtools/dataconverter/nexus_tree.py +++ b/src/pynxtools/dataconverter/nexus_tree.py @@ -28,6 +28,7 @@ It also allows for adding further nodes from the inheritance chain on the fly. """ +from functools import reduce from typing import Any, List, Literal, Optional, Set, Tuple, Union import lxml.etree as ET @@ -41,6 +42,7 @@ is_appdef, remove_namespace_from_tag, ) +from pynxtools.definitions.dev_tools.utils.nxdl_utils import get_nx_namefit NexusType = Literal[ "NX_BINARY", @@ -139,15 +141,28 @@ class NexusNode(NodeMixin): optionality: Literal["required", "recommended", "optional"] = "required" variadic: bool = False inheritance: List[ET._Element] + is_a: List["NexusNode"] + parent_of: List["NexusNode"] def _set_optionality(self): + """ + Sets the optionality of the current node + if `recommended`, `required` or `optional` is set. + Also sets the field to optional if `maxOccurs == 0` or to required + if `maxOccurs > 0`. + """ if not self.inheritance: return if self.inheritance[0].attrib.get("recommended"): self.optionality = "recommended" - elif ( - self.inheritance[0].attrib.get("optional") - or self.inheritance[0].attrib.get("minOccurs") == "0" + elif self.inheritance[0].attrib.get("required") or ( + isinstance(self, NexusGroup) + and self.occurrence_limits[0] is not None + and self.occurrence_limits[0] > 0 + ): + self.optionality = "required" + elif self.inheritance[0].attrib.get("optional") or ( + isinstance(self, NexusGroup) and self.occurrence_limits[0] == 0 ): self.optionality = "optional" @@ -172,8 +187,13 @@ def __init__( else: self.inheritance = [] self.parent = parent + self.is_a = [] + self.parent_of = [] def _construct_inheritance_chain_from_parent(self): + """ + Builds the inheritance chain of the current node based on the parent node. + """ if self.parent is None: return for xml_elem in self.parent.inheritance: @@ -221,18 +241,33 @@ def search_child_with_name( direct_child = next((x for x in self.children if x.name == name), None) if direct_child is not None: return direct_child - if name in self.get_all_children_names(): + if name in self.get_all_direct_children_names(): return self.add_inherited_node(name) return None - def get_all_children_names( - self, depth: Optional[int] = None, only_appdef: bool = False + def get_all_direct_children_names( + self, + node_type: Optional[str] = None, + nx_class: Optional[str] = None, + depth: Optional[int] = None, + only_appdef: bool = False, ) -> Set[str]: """ Get all children names of the current node up to a certain depth. Only `field`, `group` `choice` or `attribute` are considered as children. Args: + node_type (Optional[str], optional): + The tags of the children to consider. + This should either be "field", "group", "choice" or "attribute". + If None all tags are considered. + Defaults to None. + nx_class (Optional[str], optional): + The NeXus class of the group to consider. + This is only used if `node_type` is "group". + It should contain the preceding `NX` and the class name in lowercase, + e.g., "NXentry". + Defaults to None. depth (Optional[int], optional): The inheritance depth up to which get children names. `depth=1` will return only the children of the current node. @@ -251,18 +286,24 @@ def get_all_children_names( if depth is not None and (not isinstance(depth, int) or depth < 0): raise ValueError("Depth must be a positive integer or None") + tag_type = "" + if node_type == "group" and nx_class is not None: + tag_type = f"[@type='{nx_class}']" + + if node_type is not None: + search_tags = f"nx:{node_type}{tag_type}" + else: + search_tags = ( + "*[self::nx:field or self::nx:group " + "or self::nx:attribute or self::nx:choice]" + ) + names = set() for elem in self.inheritance[:depth]: if only_appdef and not is_appdef(elem): break - for subelems in elem.xpath( - ( - r"*[self::nx:field or self::nx:group " - r"or self::nx:attribute or self::nx:choice]" - ), - namespaces=namespaces, - ): + for subelems in elem.xpath(search_tags, namespaces=namespaces): if "name" in subelems.attrib: names.add(subelems.attrib["name"]) elif "type" in subelems.attrib: @@ -351,15 +392,54 @@ def get_docstring(self, depth: Optional[int] = None) -> List[str]: return docstrings def _build_inheritance_chain(self, xml_elem: ET._Element) -> List[ET._Element]: + """ + Builds the inheritance chain based on the given xml node and the inheritance + chain of this node. + + Args: + xml_elem (ET._Element): The xml element to build the inheritance chain for. + + Returns: + List[ET._Element]: + The list of xml nodes representing the inheritance chain. + This represents the direct field or group inside the specific xml file. + """ name = xml_elem.attrib.get("name") inheritance_chain = [xml_elem] for elem in self.inheritance: inherited_elem = elem.xpath( f"nx:group[@type='{xml_elem.attrib['type']}' and @name='{name}']" if name is not None - else f"nx:group[@type='{xml_elem.attrib['type']}']", + else f"nx:group[@type='{xml_elem.attrib['type']}' and not(@name)]", namespaces=namespaces, ) + if not inherited_elem and name is not None: + # Try to namefit + groups = elem.findall( + f"nx:group[@type='{xml_elem.attrib['type']}']", + namespaces=namespaces, + ) + best_group = None + best_score = -1 + for group in groups: + if name in group.attrib and not contains_uppercase( + group.attrib["name"] + ): + continue + group_name = ( + group.attrib.get("name") + if "name" in group.attrib + else group.attrib["type"][2:].upper() + ) + + score = get_nx_namefit(name, group_name) + if get_nx_namefit(name, group_name) >= best_score: + best_group = group + best_score = score + + if best_group is not None: + inherited_elem = [best_group] + if inherited_elem and inherited_elem[0] not in inheritance_chain: inheritance_chain.append(inherited_elem[0]) bc_xml_root, _ = get_nxdl_root_and_path(xml_elem.attrib["type"]) @@ -432,18 +512,19 @@ def add_inherited_node(self, name: str) -> Optional["NexusNode"]: """ for elem in self.inheritance: xml_elem = elem.xpath( - f"*[self::nx:field or self::nx:group or self::nx:attribute][@name='{name}']", + "*[self::nx:field or self::nx:group or" + f" self::nx:attribute or self::nx:choice][@name='{name}']", namespaces=namespaces, ) if not xml_elem: # Find group by naming convention xml_elem = elem.xpath( - f"*[self::nx:group][@type='NX{name.lower()}']", + "*[self::nx:group or self::nx:choice]" + f"[@type='NX{name.lower()}' and not(@name)]", namespaces=namespaces, ) if xml_elem: - new_node = self.add_node_from(xml_elem[0]) - return new_node + return self.add_node_from(xml_elem[0]) return None @@ -462,7 +543,7 @@ class NexusChoice(NexusNode): type: Literal["choice"] = "choice" def __init__(self, **data) -> None: - super().__init__(**data) + super().__init__(type=self.type, **data) self._construct_inheritance_chain_from_parent() self._set_optionality() @@ -489,7 +570,54 @@ class NexusGroup(NexusNode): Optional[int], ] = (None, None) + def _check_sibling_namefit(self): + """ + Namefits siblings at the current tree level if they are not part of the same + appdef or base class. + The function fills the `parent_of` property of this node and the `is_a` property + of the connected nodes to represent the relation. + It also adapts the optionality if enough required children are present. + """ + if not self.variadic: + return + + for sibling in self.parent.get_all_direct_children_names( + node_type=self.type, nx_class=self.nx_class + ): + if sibling == self.name or contains_uppercase(sibling): + continue + if sibling.lower() == self.name.lower(): + continue + + if get_nx_namefit(sibling, self.name) >= -1: + fit = self.parent.search_child_with_name(sibling) + if ( + self.inheritance[0] != fit.inheritance[0] + and self.inheritance[0] in fit.inheritance + ): + fit.is_a.append(self) + self.parent_of.append(fit) + + min_occurs = ( + 0 if self.occurrence_limits[0] is None else self.occurrence_limits[0] + ) + min_occurs = 1 if self.optionality == "required" else min_occurs + + required_children = reduce( + lambda x, y: x + (1 if y.optionality == "required" else 0), + self.parent_of, + 0, + ) + + if required_children >= min_occurs: + self.optionality = "optional" + def _set_occurence_limits(self): + """ + Sets the occurence limits of the current group. + Searches the inheritance chain until a value is found. + Otherwise, the occurence_limits are set to (None, None). + """ if not self.inheritance: return xml_elem = self.inheritance[0] @@ -511,6 +639,7 @@ def __init__(self, nx_class: str, **data) -> None: self.nx_class = nx_class self._set_occurence_limits() self._set_optionality() + self._check_sibling_namefit() def __repr__(self) -> str: return ( @@ -561,18 +690,31 @@ class NexusEntity(NexusNode): shape: Optional[Tuple[Optional[int], ...]] = None def _set_type(self): + """ + Sets the dtype of the current entity based on the values in the inheritance chain. + The first vale found is used. + """ for elem in self.inheritance: if "type" in elem.attrib: self.dtype = elem.attrib["type"] return def _set_unit(self): + """ + Sets the unit of the current entity based on the values in the inheritance chain. + The first vale found is used. + """ for elem in self.inheritance: if "units" in elem.attrib: self.unit = elem.attrib["units"] return def _set_items(self): + """ + Sets the enumeration items of the current entity + based on the values in the inheritance chain. + The first vale found is used. + """ if not self.dtype == "NX_CHAR": return for elem in self.inheritance: @@ -584,6 +726,10 @@ def _set_items(self): return def _set_shape(self): + """ + Sets the shape of the current entity based on the values in the inheritance chain. + The first vale found is used. + """ for elem in self.inheritance: dimension = elem.find(f"nx:dimensions", namespaces=namespaces) if dimension is not None: @@ -638,7 +784,7 @@ def populate_tree_from_parents(node: NexusNode): node (NexusNode): The current node from which to populate the tree. """ - for child in node.get_all_children_names(only_appdef=True): + for child in node.get_all_direct_children_names(only_appdef=True): child_node = node.search_child_with_name(child) populate_tree_from_parents(child_node) diff --git a/src/pynxtools/dataconverter/validation.py b/src/pynxtools/dataconverter/validation.py index 5c30a8104..2f100a9d7 100644 --- a/src/pynxtools/dataconverter/validation.py +++ b/src/pynxtools/dataconverter/validation.py @@ -206,7 +206,7 @@ def get_variations_of(node: NexusNode, keys: Mapping[str, Any]) -> List[str]: continue if ( get_nx_namefit(name2fit, node.name) >= 0 - and key not in node.parent.get_all_children_names() + and key not in node.parent.get_all_direct_children_names() ): variations.append(key) if nx_name is not None and not variations: @@ -239,7 +239,7 @@ def check_nxdata(): data_node = node.search_child_with_name((signal, "DATA")) data_bc_node = node.search_child_with_name("DATA") data_node.inheritance.append(data_bc_node.inheritance[0]) - for child in data_node.get_all_children_names(): + for child in data_node.get_all_direct_children_names(): data_node.search_child_with_name(child) handle_field( @@ -271,7 +271,7 @@ def check_nxdata(): axis_node = node.search_child_with_name((axis, "AXISNAME")) axis_bc_node = node.search_child_with_name("AXISNAME") axis_node.inheritance.append(axis_bc_node.inheritance[0]) - for child in axis_node.get_all_children_names(): + for child in axis_node.get_all_direct_children_names(): axis_node.search_child_with_name(child) handle_field( @@ -328,13 +328,24 @@ def check_nxdata(): def handle_group(node: NexusGroup, keys: Mapping[str, Any], prev_path: str): variants = get_variations_of(node, keys) - if not variants: - if node.optionality == "required" and node.type in missing_type_err: - collector.collect_and_log( - f"{prev_path}/{node.name}", missing_type_err.get(node.type), None - ) + if node.parent_of: + for child in node.parent_of: + variants += get_variations_of(child, keys) + if ( + not variants + and node.optionality == "required" + and node.type in missing_type_err + ): + collector.collect_and_log( + f"{prev_path}/{node.name}", + missing_type_err.get(node.type), + None, + ) return for variant in variants: + if variant in [node.name for node in node.parent_of]: + # Don't process if this is actually a sub-variant of this group + continue nx_class, _ = split_class_and_name_of(variant) if not isinstance(keys[variant], Mapping): if nx_class is not None: @@ -499,16 +510,12 @@ def is_documented(key: str, node: NexusNode) -> bool: return True for name in key[1:].replace("@", "").split("/"): - children = node.get_all_children_names() + children = node.get_all_direct_children_names() best_name = best_namefit_of(name, children) if best_name is None: return False - resolver = Resolver("name", relax=True) - child_node = resolver.get(node, best_name) - node = ( - node.add_inherited_node(best_name) if child_node is None else child_node - ) + node = node.search_child_with_name(best_name) if isinstance(mapping[key], dict) and "link" in mapping[key]: # TODO: Follow link and check consistency with current field @@ -612,7 +619,7 @@ def populate_full_tree(node: NexusNode, max_depth: Optional[int] = 5, depth: int # but it does while recursing the tree and it should # be fixed. return - for child in node.get_all_children_names(): + for child in node.get_all_direct_children_names(): child_node = node.search_child_with_name(child) populate_full_tree(child_node, max_depth=max_depth, depth=depth + 1) diff --git a/tests/dataconverter/test_nexus_tree.py b/tests/dataconverter/test_nexus_tree.py index e3724b3ff..67cf33c44 100644 --- a/tests/dataconverter/test_nexus_tree.py +++ b/tests/dataconverter/test_nexus_tree.py @@ -42,7 +42,8 @@ def test_correct_extension_of_tree(): def get_node_fields(tree: NexusNode) -> List[Tuple[str, Any]]: return list( filter( - lambda x: not x[0].startswith("_") and x[0] not in "inheritance", + lambda x: not x[0].startswith("_") + and x[0] not in ("inheritance", "is_a", "parent_of"), tree.__dict__.items(), ) ) diff --git a/tests/dataconverter/test_validation.py b/tests/dataconverter/test_validation.py index cb68b60e8..135d0476c 100644 --- a/tests/dataconverter/test_validation.py +++ b/tests/dataconverter/test_validation.py @@ -21,46 +21,49 @@ import numpy as np import pytest - from pynxtools.dataconverter.validation import validate_dict_against def get_data_dict(): return { - "/my_entry/optional_parent/required_child": 1, - "/my_entry/optional_parent/optional_child": 1, - "/my_entry/nxodd_name/float_value": 2.0, - "/my_entry/nxodd_name/float_value/@units": "nm", - "/my_entry/nxodd_name/bool_value": True, - "/my_entry/nxodd_name/bool_value/@units": "", - "/my_entry/nxodd_name/int_value": 2, - "/my_entry/nxodd_name/int_value/@units": "eV", - "/my_entry/nxodd_name/posint_value": np.array([1, 2, 3], dtype=np.int8), - "/my_entry/nxodd_name/posint_value/@units": "kg", - "/my_entry/nxodd_name/char_value": "just chars", - "/my_entry/nxodd_name/char_value/@units": "", - "/my_entry/nxodd_name/type": "2nd type", - "/my_entry/nxodd_name/date_value": "2022-01-22T12:14:12.05018+00:00", - "/my_entry/nxodd_name/date_value/@units": "", - "/my_entry/nxodd_two_name/bool_value": True, - "/my_entry/nxodd_two_name/bool_value/@units": "", - "/my_entry/nxodd_two_name/int_value": 2, - "/my_entry/nxodd_two_name/int_value/@units": "eV", - "/my_entry/nxodd_two_name/posint_value": np.array([1, 2, 3], dtype=np.int8), - "/my_entry/nxodd_two_name/posint_value/@units": "kg", - "/my_entry/nxodd_two_name/char_value": "just chars", - "/my_entry/nxodd_two_name/char_value/@units": "", - "/my_entry/nxodd_two_name/type": "2nd type", - "/my_entry/nxodd_two_name/date_value": "2022-01-22T12:14:12.05018+00:00", - "/my_entry/nxodd_two_name/date_value/@units": "", - "/my_entry/my_group/required_field": 1, - "/my_entry/definition": "NXtest", - "/my_entry/definition/@version": "2.4.6", - "/my_entry/program_name": "Testing program", - "/my_entry/my_group/optional_field": 1, - "/my_entry/required_group/description": "An example description", - "/my_entry/required_group2/description": "An example description", - "/my_entry/optional_parent/req_group_in_opt_group/data": 1, + "/ENTRY[my_entry]/optional_parent/required_child": 1, + "/ENTRY[my_entry]/optional_parent/optional_child": 1, + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/float_value": 2.0, + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/float_value/@units": "nm", + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/bool_value": True, + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/bool_value/@units": "", + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/int_value": 2, + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/int_value/@units": "eV", + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/posint_value": np.array( + [1, 2, 3], dtype=np.int8 + ), + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/posint_value/@units": "kg", + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/char_value": "just chars", + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/char_value/@units": "", + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/type": "2nd type", + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/date_value": "2022-01-22T12:14:12.05018+00:00", + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/date_value/@units": "", + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/bool_value": True, + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/bool_value/@units": "", + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/int_value": 2, + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/int_value/@units": "eV", + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/posint_value": np.array( + [1, 2, 3], dtype=np.int8 + ), + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/posint_value/@units": "kg", + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/char_value": "just chars", + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/char_value/@units": "", + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/type": "2nd type", + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/date_value": "2022-01-22T12:14:12.05018+00:00", + "/ENTRY[my_entry]/NXODD_name[nxodd_two_name]/date_value/@units": "", + "/ENTRY[my_entry]/OPTIONAL_group[my_group]/required_field": 1, + "/ENTRY[my_entry]/definition": "NXtest", + "/ENTRY[my_entry]/definition/@version": "2.4.6", + "/ENTRY[my_entry]/program_name": "Testing program", + "/ENTRY[my_entry]/OPTIONAL_group[my_group]/optional_field": 1, + "/ENTRY[my_entry]/required_group/description": "An example description", + "/ENTRY[my_entry]/required_group2/description": "An example description", + "/ENTRY[my_entry]/optional_parent/req_group_in_opt_group/data": 1, "/@default": "Some NXroot attribute", } @@ -86,7 +89,9 @@ def alter_dict(new_values: Dict[str, Any], data_dict: Dict[str, Any]) -> Dict[st [ pytest.param(get_data_dict(), id="valid-unaltered-data-dict"), pytest.param( - remove_from_dict("/my_entry/nxodd_name/float_value", get_data_dict()), + remove_from_dict( + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/float_value", get_data_dict() + ), id="removed-optional-value", ), ], @@ -101,8 +106,10 @@ def test_valid_data_dict(caplog, data_dict): "data_dict, error_message", [ pytest.param( - remove_from_dict("/my_entry/nxodd_name/bool_value", get_data_dict()), - "The data entry corresponding to /my_entry/nxodd_name/bool_value is required and hasn't been supplied by the reader.", + remove_from_dict( + "/ENTRY[my_entry]/NXODD_name[nxodd_name]/bool_value", get_data_dict() + ), + "The data entry corresponding to /ENTRY[my_entry]/NXODD_name[nxodd_name]/bool_value is required and hasn't been supplied by the reader.", id="missing-required-value", ) ],