From 9eab38c3ebed887b52a725fdf539feb92a8cbc1e Mon Sep 17 00:00:00 2001 From: domna Date: Tue, 21 May 2024 08:18:49 +0200 Subject: [PATCH] Improved get_children_names function --- src/pynxtools/dataconverter/nexus_tree.py | 40 ++++++++++++++---- src/pynxtools/dataconverter/validation.py | 51 ++++++++++++++--------- 2 files changed, 63 insertions(+), 28 deletions(-) diff --git a/src/pynxtools/dataconverter/nexus_tree.py b/src/pynxtools/dataconverter/nexus_tree.py index b21bbc6ba..2ad519bab 100644 --- a/src/pynxtools/dataconverter/nexus_tree.py +++ b/src/pynxtools/dataconverter/nexus_tree.py @@ -219,16 +219,32 @@ def search_child_with_name( direct_child = next((x for x in self.children if x.name == name), None) if direct_child is not None: return direct_child - if name in self.get_all_children_names(): + if name in self.get_all_direct_children_names(): return self.add_inherited_node(name) return None - def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]: + def get_all_direct_children_names( + self, + node_type: Optional[str] = None, + nx_class: Optional[str] = None, + depth: Optional[int] = None, + ) -> Set[str]: """ Get all children names of the current node up to a certain depth. Only `field`, `group` `choice` or `attribute` are considered as children. Args: + node_type (Optional[str], optional): + The tags of the children to consider. + This should either be "field", "group", "choice" or "attribute". + If None all tags are considered. + Defaults to None. + nx_class (Optional[str], optional): + The NeXus class of the group to consider. + This is only used if `node_type` is "group". + It should contain the preceding `NX` and the class name in lowercase, + e.g., "NXentry". + Defaults to None. depth (Optional[int], optional): The inheritance depth up to which get children names. `depth=1` will return only the children of the current node. @@ -244,15 +260,21 @@ def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]: if depth is not None and (not isinstance(depth, int) or depth < 0): raise ValueError("Depth must be a positive integer or None") + tag_type = "" + if node_type == "group" and nx_class is not None: + tag_type = f"[@type='{nx_class}']" + + if node_type is not None: + search_tags = f"*[self::nx:{node_type}{tag_type}]" + else: + search_tags = ( + r"*[self::nx:field or self::nx:group " + r"or self::nx:attribute or self::nx:choice]" + ) + names = set() for elem in self.inheritance[:depth]: - for subelems in elem.xpath( - ( - r"*[self::nx:field or self::nx:group " - r"or self::nx:attribute or self::nx:choice]" - ), - namespaces=namespaces, - ): + for subelems in elem.xpath(search_tags, namespaces=namespaces): if "name" in subelems.attrib: names.add(subelems.attrib["name"]) elif "type" in subelems.attrib: diff --git a/src/pynxtools/dataconverter/validation.py b/src/pynxtools/dataconverter/validation.py index dbba84af4..ba19bfb76 100644 --- a/src/pynxtools/dataconverter/validation.py +++ b/src/pynxtools/dataconverter/validation.py @@ -44,12 +44,20 @@ from pynxtools.definitions.dev_tools.utils.nxdl_utils import get_nx_namefit -def best_namefit_of_( - name: str, concepts: Set[str], nx_class: Optional[str] = None -) -> str: - # TODO: Find the best namefit of name in concepts - # Consider nx_class if it is not None - ... +def best_namefit_of_(name: str, concepts: Set[str]) -> str: + if not concepts: + return None + + if name in concepts: + return name + + best_match, score = max( + map(lambda x: (x, get_nx_namefit(name, x)), concepts), key=lambda x: x[1] + ) + if score < 0: + return None + + return best_match def validate_hdf_group_against(appdef: str, data: h5py.Group): @@ -64,9 +72,11 @@ def validate_hdf_group_against(appdef: str, data: h5py.Group): # Allow for 10000 cache entries. This should be enough for most cases @cached( cache=LRUCache(maxsize=10000), - key=lambda path, _: hashkey(path), + key=lambda path, *_: hashkey(path), ) - def find_node_for(path: str, nx_class: Optional[str] = None) -> Optional[NexusNode]: + def find_node_for( + path: str, node_type: Optional[str] = None, nx_class: Optional[str] = None + ) -> Optional[NexusNode]: if path == "": return tree @@ -75,10 +85,7 @@ def find_node_for(path: str, nx_class: Optional[str] = None) -> Optional[NexusNo best_child = best_namefit_of_( last_elem, - # TODO: Consider renaming `get_all_children_names` to - # `get_all_direct_children_names`. Because that's what it is. - node.get_all_children_names(), - nx_class, + node.get_all_direct_children_names(nx_class=nx_class, node_type=node_type), ) if best_child is None: return None @@ -92,7 +99,9 @@ def remove_from_req_fields(path: str): def handle_group(path: str, data: h5py.Group): node = find_node_for(path, data.attrs.get("NX_class")) if node is None: - # TODO: Log undocumented + collector.collect_and_log( + path, ValidationProblem.MissingDocumentation, None + ) return # TODO: Do actual group checks @@ -100,7 +109,9 @@ def handle_group(path: str, data: h5py.Group): def handle_field(path: str, data: h5py.Dataset): node = find_node_for(path) if node is None: - # TODO: Log undocumented + collector.collect_and_log( + path, ValidationProblem.MissingDocumentation, None + ) return remove_from_req_fields(f"{path}") @@ -110,7 +121,9 @@ def handle_attributes(path: str, attribute_names: h5py.AttributeManager): for attr_name in attribute_names: node = find_node_for(f"{path}/{attr_name}") if node is None: - # TODO: Log undocumented + collector.collect_and_log( + path, ValidationProblem.MissingDocumentation, None + ) continue remove_from_req_fields(f"{path}/@{attr_name}") @@ -282,7 +295,7 @@ def get_variations_of(node: NexusNode, keys: Mapping[str, Any]) -> List[str]: continue if ( get_nx_namefit(name2fit, node.name) >= 0 - and key not in node.parent.get_all_children_names() + and key not in node.parent.get_all_direct_children_names() ): variations.append(key) if nx_name is not None and not variations: @@ -315,7 +328,7 @@ def check_nxdata(): data_node = node.search_child_with_name((signal, "DATA")) data_bc_node = node.search_child_with_name("DATA") data_node.inheritance.append(data_bc_node.inheritance[0]) - for child in data_node.get_all_children_names(): + for child in data_node.get_all_direct_children_names(): data_node.search_child_with_name(child) handle_field( @@ -347,7 +360,7 @@ def check_nxdata(): axis_node = node.search_child_with_name((axis, "AXISNAME")) axis_bc_node = node.search_child_with_name("AXISNAME") axis_node.inheritance.append(axis_bc_node.inheritance[0]) - for child in axis_node.get_all_children_names(): + for child in axis_node.get_all_direct_children_names(): axis_node.search_child_with_name(child) handle_field( @@ -575,7 +588,7 @@ def is_documented(key: str, node: NexusNode) -> bool: return True for name in key[1:].replace("@", "").split("/"): - children = node.get_all_children_names() + children = node.get_all_direct_children_names() best_name = best_namefit_of(name, children) if best_name is None: return False