Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validation fixes #350

Merged
merged 9 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"h5py>=3.6.0",
"xarray>=0.20.2",
"PyYAML>=6.0",
"numpy>=1.21.2",
"numpy>=1.21.2,<2.0.0",
"pandas>=1.3.2",
"ase>=3.19.0",
"mergedeep",
Expand Down
186 changes: 166 additions & 20 deletions src/pynxtools/dataconverter/nexus_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
It also allows for adding further nodes from the inheritance chain on the fly.
"""

from functools import reduce
from typing import Any, List, Literal, Optional, Set, Tuple, Union

import lxml.etree as ET
Expand All @@ -41,6 +42,7 @@
is_appdef,
remove_namespace_from_tag,
)
from pynxtools.definitions.dev_tools.utils.nxdl_utils import get_nx_namefit

NexusType = Literal[
"NX_BINARY",
Expand Down Expand Up @@ -139,15 +141,28 @@ class NexusNode(NodeMixin):
optionality: Literal["required", "recommended", "optional"] = "required"
variadic: bool = False
inheritance: List[ET._Element]
is_a: List["NexusNode"]
parent_of: List["NexusNode"]

def _set_optionality(self):
"""
Sets the optionality of the current node
if `recommended`, `required` or `optional` is set.
Also sets the field to optional if `maxOccurs == 0` or to required
if `maxOccurs > 0`.
"""
if not self.inheritance:
return
if self.inheritance[0].attrib.get("recommended"):
self.optionality = "recommended"
elif (
self.inheritance[0].attrib.get("optional")
or self.inheritance[0].attrib.get("minOccurs") == "0"
elif self.inheritance[0].attrib.get("required") or (
isinstance(self, NexusGroup)
and self.occurrence_limits[0] is not None
and self.occurrence_limits[0] > 0
):
self.optionality = "required"
elif self.inheritance[0].attrib.get("optional") or (
isinstance(self, NexusGroup) and self.occurrence_limits[0] == 0
):
self.optionality = "optional"

Expand All @@ -172,8 +187,13 @@ def __init__(
else:
self.inheritance = []
self.parent = parent
self.is_a = []
self.parent_of = []

def _construct_inheritance_chain_from_parent(self):
"""
Builds the inheritance chain of the current node based on the parent node.
"""
if self.parent is None:
return
for xml_elem in self.parent.inheritance:
Expand Down Expand Up @@ -221,18 +241,33 @@ def search_child_with_name(
direct_child = next((x for x in self.children if x.name == name), None)
if direct_child is not None:
return direct_child
if name in self.get_all_children_names():
if name in self.get_all_direct_children_names():
return self.add_inherited_node(name)
return None

def get_all_children_names(
self, depth: Optional[int] = None, only_appdef: bool = False
def get_all_direct_children_names(
self,
node_type: Optional[str] = None,
nx_class: Optional[str] = None,
depth: Optional[int] = None,
only_appdef: bool = False,
) -> Set[str]:
"""
Get all children names of the current node up to a certain depth.
Only `field`, `group` `choice` or `attribute` are considered as children.

Args:
node_type (Optional[str], optional):
The tags of the children to consider.
This should either be "field", "group", "choice" or "attribute".
If None all tags are considered.
Defaults to None.
nx_class (Optional[str], optional):
The NeXus class of the group to consider.
This is only used if `node_type` is "group".
It should contain the preceding `NX` and the class name in lowercase,
e.g., "NXentry".
Defaults to None.
depth (Optional[int], optional):
The inheritance depth up to which get children names.
`depth=1` will return only the children of the current node.
Expand All @@ -251,18 +286,24 @@ def get_all_children_names(
if depth is not None and (not isinstance(depth, int) or depth < 0):
raise ValueError("Depth must be a positive integer or None")

tag_type = ""
if node_type == "group" and nx_class is not None:
tag_type = f"[@type='{nx_class}']"

if node_type is not None:
search_tags = f"nx:{node_type}{tag_type}"
else:
search_tags = (
"*[self::nx:field or self::nx:group "
"or self::nx:attribute or self::nx:choice]"
)

names = set()
for elem in self.inheritance[:depth]:
if only_appdef and not is_appdef(elem):
break

for subelems in elem.xpath(
(
r"*[self::nx:field or self::nx:group "
r"or self::nx:attribute or self::nx:choice]"
),
namespaces=namespaces,
):
for subelems in elem.xpath(search_tags, namespaces=namespaces):
if "name" in subelems.attrib:
names.add(subelems.attrib["name"])
elif "type" in subelems.attrib:
Expand Down Expand Up @@ -351,15 +392,54 @@ def get_docstring(self, depth: Optional[int] = None) -> List[str]:
return docstrings

def _build_inheritance_chain(self, xml_elem: ET._Element) -> List[ET._Element]:
"""
Builds the inheritance chain based on the given xml node and the inheritance
chain of this node.

Args:
xml_elem (ET._Element): The xml element to build the inheritance chain for.

Returns:
List[ET._Element]:
The list of xml nodes representing the inheritance chain.
This represents the direct field or group inside the specific xml file.
"""
name = xml_elem.attrib.get("name")
inheritance_chain = [xml_elem]
for elem in self.inheritance:
inherited_elem = elem.xpath(
f"nx:group[@type='{xml_elem.attrib['type']}' and @name='{name}']"
if name is not None
else f"nx:group[@type='{xml_elem.attrib['type']}']",
else f"nx:group[@type='{xml_elem.attrib['type']}' and not(@name)]",
namespaces=namespaces,
)
if not inherited_elem and name is not None:
# Try to namefit
groups = elem.findall(
f"nx:group[@type='{xml_elem.attrib['type']}']",
namespaces=namespaces,
)
best_group = None
best_score = -1
for group in groups:
if name in group.attrib and not contains_uppercase(
group.attrib["name"]
):
continue
group_name = (
group.attrib.get("name")
if "name" in group.attrib
else group.attrib["type"][2:].upper()
)

score = get_nx_namefit(name, group_name)
if get_nx_namefit(name, group_name) >= best_score:
best_group = group
best_score = score

if best_group is not None:
inherited_elem = [best_group]

if inherited_elem and inherited_elem[0] not in inheritance_chain:
inheritance_chain.append(inherited_elem[0])
bc_xml_root, _ = get_nxdl_root_and_path(xml_elem.attrib["type"])
Expand Down Expand Up @@ -432,18 +512,19 @@ def add_inherited_node(self, name: str) -> Optional["NexusNode"]:
"""
for elem in self.inheritance:
xml_elem = elem.xpath(
f"*[self::nx:field or self::nx:group or self::nx:attribute][@name='{name}']",
"*[self::nx:field or self::nx:group or"
f" self::nx:attribute or self::nx:choice][@name='{name}']",
namespaces=namespaces,
)
if not xml_elem:
# Find group by naming convention
xml_elem = elem.xpath(
f"*[self::nx:group][@type='NX{name.lower()}']",
"*[self::nx:group or self::nx:choice]"
f"[@type='NX{name.lower()}' and not(@name)]",
namespaces=namespaces,
)
if xml_elem:
new_node = self.add_node_from(xml_elem[0])
return new_node
return self.add_node_from(xml_elem[0])
return None


Expand All @@ -462,7 +543,7 @@ class NexusChoice(NexusNode):
type: Literal["choice"] = "choice"

def __init__(self, **data) -> None:
super().__init__(**data)
super().__init__(type=self.type, **data)
self._construct_inheritance_chain_from_parent()
self._set_optionality()

Expand All @@ -489,7 +570,54 @@ class NexusGroup(NexusNode):
Optional[int],
] = (None, None)

def _check_sibling_namefit(self):
"""
Namefits siblings at the current tree level if they are not part of the same
appdef or base class.
The function fills the `parent_of` property of this node and the `is_a` property
of the connected nodes to represent the relation.
It also adapts the optionality if enough required children are present.
"""
if not self.variadic:
return

for sibling in self.parent.get_all_direct_children_names(
node_type=self.type, nx_class=self.nx_class
):
if sibling == self.name or contains_uppercase(sibling):
continue
if sibling.lower() == self.name.lower():
continue

if get_nx_namefit(sibling, self.name) >= -1:
fit = self.parent.search_child_with_name(sibling)
if (
self.inheritance[0] != fit.inheritance[0]
and self.inheritance[0] in fit.inheritance
):
fit.is_a.append(self)
self.parent_of.append(fit)

min_occurs = (
0 if self.occurrence_limits[0] is None else self.occurrence_limits[0]
)
min_occurs = 1 if self.optionality == "required" else min_occurs

required_children = reduce(
lambda x, y: x + (1 if y.optionality == "required" else 0),
self.parent_of,
0,
)

if required_children >= min_occurs:
self.optionality = "optional"

def _set_occurence_limits(self):
"""
Sets the occurence limits of the current group.
Searches the inheritance chain until a value is found.
Otherwise, the occurence_limits are set to (None, None).
"""
if not self.inheritance:
return
xml_elem = self.inheritance[0]
Expand All @@ -511,6 +639,7 @@ def __init__(self, nx_class: str, **data) -> None:
self.nx_class = nx_class
self._set_occurence_limits()
self._set_optionality()
self._check_sibling_namefit()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -561,18 +690,31 @@ class NexusEntity(NexusNode):
shape: Optional[Tuple[Optional[int], ...]] = None

def _set_type(self):
"""
Sets the dtype of the current entity based on the values in the inheritance chain.
The first vale found is used.
"""
for elem in self.inheritance:
if "type" in elem.attrib:
self.dtype = elem.attrib["type"]
return

def _set_unit(self):
"""
Sets the unit of the current entity based on the values in the inheritance chain.
The first vale found is used.
"""
for elem in self.inheritance:
if "units" in elem.attrib:
self.unit = elem.attrib["units"]
return

def _set_items(self):
"""
Sets the enumeration items of the current entity
based on the values in the inheritance chain.
The first vale found is used.
"""
if not self.dtype == "NX_CHAR":
return
for elem in self.inheritance:
Expand All @@ -584,6 +726,10 @@ def _set_items(self):
return

def _set_shape(self):
"""
Sets the shape of the current entity based on the values in the inheritance chain.
The first vale found is used.
"""
for elem in self.inheritance:
dimension = elem.find(f"nx:dimensions", namespaces=namespaces)
if dimension is not None:
Expand Down Expand Up @@ -638,7 +784,7 @@ def populate_tree_from_parents(node: NexusNode):
node (NexusNode):
The current node from which to populate the tree.
"""
for child in node.get_all_children_names(only_appdef=True):
for child in node.get_all_direct_children_names(only_appdef=True):
child_node = node.search_child_with_name(child)
populate_tree_from_parents(child_node)

Expand Down
Loading
Loading