Skip to content

Commit

Permalink
Support for appdef extends (#339)
Browse files Browse the repository at this point in the history
* Read extends keyword from file

* Insert extends parents into the inheritance chain

* Automatically populate tree from appdef parents

* Only populate tree if parents are present

* Docstring improvements

* Fix exact match in NX_CLASS[path] notation

* If minOccurs == 0, set the group to optional

* Add extended NXtest
  • Loading branch information
domna authored Jun 14, 2024
1 parent 3a7d63d commit 70d74b3
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 20 deletions.
23 changes: 23 additions & 0 deletions src/pynxtools/data/NXtest_extended.nxdl.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<?xml version="1.0" encoding="UTF-8"?>
<?xml-stylesheet type="text/xsl" href="nxdlformat.xsl" ?>
<definition xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://definition.nexusformat.org/nxdl/3.1 ../nxdl.xsd"
xmlns="http://definition.nexusformat.org/nxdl/3.1"
name="NXtest_extended"
extends="NXtest"
type="group"
category="application"
>
<doc>This is a dummy NXDL to test an extended application definition.</doc>
<group type="NXentry">
<field name="definition">
<doc>This is a dummy NXDL to test out the dataconverter.</doc>
<enumeration>
<item value="NXtest_extended"/>
</enumeration>
</field>
<field name="extended_field" type="NX_FLOAT" units="NX_ENERGY">
<doc>A dummy entry for an extended field.</doc>
</field>
</group>
</definition>
89 changes: 81 additions & 8 deletions src/pynxtools/dataconverter/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,75 @@ def get_nxdl_name_from_elem(xml_element) -> str:
return name_to_add


def get_nxdl_name_for(xml_elem: ET._Element) -> Optional[str]:
"""
Get the name of the element from the NXDL element.
For an entity having a name this is just the name.
For groups it is the uppercase type without NX, e.g. "ENTRY" for "NXentry".
Args:
xml_elem (ET._Element): The xml element to get the name for.
Returns:
Optional[str]:
The name of the element.
None if the xml element has no name or type attribute.
"""
""""""
if "name" in xml_elem.attrib:
return xml_elem.attrib["name"]
if "type" in xml_elem.attrib:
return convert_nexus_to_caps(xml_elem.attrib["type"])
return None


def get_appdef_root(xml_elem: ET._Element) -> ET._Element:
"""
Get the root element of the tree of xml_elem
Args:
xml_elem (ET._Element): The element for which to get the root element.
Returns:
ET._Element: The root element of the tree.
"""
return xml_elem.getroottree().getroot()


def is_appdef(xml_elem: ET._Element) -> bool:
"""
Check whether the xml element is part of an application definition.
Args:
xml_elem (ET._Element): The xml_elem whose tree to check.
Returns:
bool: True if the xml_elem is part of an application definition.
"""
return get_appdef_root(xml_elem).attrib.get("category") == "application"


def get_all_parents_for(xml_elem: ET._Element) -> List[ET._Element]:
"""
Get all parents from the nxdl (via extends keyword)
Args:
xml_elem (ET._Element): The element to get the parents for.
Returns:
List[ET._Element]: The list of parents xml nodes.
"""
root = get_appdef_root(xml_elem)
inheritance_chain = []
extends = root.get("extends")
while extends is not None and extends != "NXobject":
parent_xml_root, _ = get_nxdl_root_and_path(extends)
extends = parent_xml_root.get("extends")
inheritance_chain.append(parent_xml_root)

return inheritance_chain


def get_nxdl_root_and_path(nxdl: str):
"""Get xml root element and file path from nxdl name e.g. NXapm.
Expand All @@ -213,16 +282,20 @@ def get_nxdl_root_and_path(nxdl: str):
FileNotFoundError
Error if no file with the given nxdl name is found.
"""

# Reading in the NXDL and generating a template
definitions_path = nexus.get_nexus_definitions_path()
if nxdl == "NXtest":
nxdl_f_path = os.path.join(
f"{os.path.abspath(os.path.dirname(__file__))}/../",
"data",
"NXtest.nxdl.xml",
)
elif nxdl == "NXroot":
nxdl_f_path = os.path.join(definitions_path, "base_classes", "NXroot.nxdl.xml")
data_path = os.path.join(
f"{os.path.abspath(os.path.dirname(__file__))}/../",
"data",
)
special_names = {
"NXtest": os.path.join(data_path, "NXtest.nxdl.xml"),
"NXtest_extended": os.path.join(data_path, "NXtest_extended.nxdl.xml"),
}

if nxdl in special_names:
nxdl_f_path = special_names[nxdl]
else:
nxdl_f_path = os.path.join(
definitions_path, "contributed_definitions", f"{nxdl}.nxdl.xml"
Expand Down
52 changes: 43 additions & 9 deletions src/pynxtools/dataconverter/nexus_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@

from pynxtools.dataconverter.helpers import (
contains_uppercase,
get_all_parents_for,
get_nxdl_name_for,
get_nxdl_root_and_path,
is_appdef,
remove_namespace_from_tag,
)

Expand Down Expand Up @@ -92,7 +95,7 @@
]

# This is the NeXus namespace for finding tags.
# It's updated from the nxdl file when `generate_tree_from`` is called.
# It's updated from the nxdl file when `generate_tree_from` is called.
namespaces = {"nx": "http://definition.nexusformat.org/nxdl/3.1"}


Expand All @@ -117,10 +120,6 @@ class NexusNode(NodeMixin):
This is set automatically on init and will be True if the name contains
any uppercase characets and False otherwise.
Defaults to False.
variadic_siblings (List[InstanceOf["NexusNode"]]):
Variadic siblings are names which are connected to each other, e.g.,
`AXISNAME` and `AXISNAME_indices` belong together and are variadic siblings.
Defaults to [].
inheritance (List[InstanceOf[ET._Element]]):
The inheritance chain of the node.
The first element of the list is the xml representation of this node.
Expand All @@ -146,7 +145,10 @@ def _set_optionality(self):
return
if self.inheritance[0].attrib.get("recommended"):
self.optionality = "recommended"
elif self.inheritance[0].attrib.get("optional"):
elif (
self.inheritance[0].attrib.get("optional")
or self.inheritance[0].attrib.get("minOccurs") == "0"
):
self.optionality = "optional"

def __init__(
Expand Down Expand Up @@ -223,7 +225,9 @@ def search_child_with_name(
return self.add_inherited_node(name)
return None

def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]:
def get_all_children_names(
self, depth: Optional[int] = None, only_appdef: bool = False
) -> Set[str]:
"""
Get all children names of the current node up to a certain depth.
Only `field`, `group` `choice` or `attribute` are considered as children.
Expand All @@ -234,6 +238,9 @@ def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]:
`depth=1` will return only the children of the current node.
`depth=None` will return all children names of all parents.
Defaults to None.
only_appdef (bool, optional):
Only considers appdef nodes as children.
Defaults to False.
Raises:
ValueError: If depth is not int or negativ.
Expand All @@ -246,6 +253,9 @@ def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]:

names = set()
for elem in self.inheritance[:depth]:
if only_appdef and not is_appdef(elem):
break

for subelems in elem.xpath(
(
r"*[self::nx:field or self::nx:group "
Expand Down Expand Up @@ -354,6 +364,7 @@ def _build_inheritance_chain(self, xml_elem: ET._Element) -> List[ET._Element]:
inheritance_chain.append(inherited_elem[0])
bc_xml_root, _ = get_nxdl_root_and_path(xml_elem.attrib["type"])
inheritance_chain.append(bc_xml_root)
inheritance_chain += get_all_parents_for(bc_xml_root)

return inheritance_chain

Expand All @@ -371,13 +382,15 @@ def add_node_from(self, xml_elem: ET._Element) -> Optional["NexusNode"]:
The children node which was added.
None if the tag of the xml element is not known.
"""
default_optionality = "required" if is_appdef(xml_elem) else "optional"
tag = remove_namespace_from_tag(xml_elem.tag)
if tag in ("field", "attribute"):
name = xml_elem.attrib.get("name")
current_elem = NexusEntity(
parent=self,
name=name,
type=tag,
optionality=default_optionality,
)
elif tag == "group":
name = xml_elem.attrib.get("name", xml_elem.attrib["type"][2:].upper())
Expand All @@ -388,12 +401,14 @@ def add_node_from(self, xml_elem: ET._Element) -> Optional["NexusNode"]:
name=name,
nx_class=xml_elem.attrib["type"],
inheritance=inheritance_chain,
optionality=default_optionality,
)
elif tag == "choice":
current_elem = NexusChoice(
parent=self,
name=xml_elem.attrib["name"],
variadic=contains_uppercase(xml_elem.attrib["name"]),
optionality=default_optionality,
)
else:
# TODO: Tags: link
Expand Down Expand Up @@ -428,7 +443,6 @@ def add_inherited_node(self, name: str) -> Optional["NexusNode"]:
)
if xml_elem:
new_node = self.add_node_from(xml_elem[0])
new_node.optionality = "optional"
return new_node
return None

Expand Down Expand Up @@ -616,6 +630,19 @@ def __repr__(self) -> str:
return f"{self.name} ({self.optionality[:3]})"


def populate_tree_from_parents(node: NexusNode):
"""
Recursively populate the tree from the appdef parents (via extends keyword).
Args:
node (NexusNode):
The current node from which to populate the tree.
"""
for child in node.get_all_children_names(only_appdef=True):
child_node = node.search_child_with_name(child)
populate_tree_from_parents(child_node)


def generate_tree_from(appdef: str) -> NexusNode:
"""
Generates a NexusNode tree from an application definition.
Expand Down Expand Up @@ -655,14 +682,17 @@ def add_children_to(parent: NexusNode, xml_elem: ET._Element) -> None:
global namespaces
namespaces = {"nx": appdef_xml_root.nsmap[None]}

appdef_inheritance_chain = [appdef_xml_root]
appdef_inheritance_chain += get_all_parents_for(appdef_xml_root)

tree = NexusGroup(
name=appdef_xml_root.attrib["name"],
nx_class="NXroot",
type="group",
optionality="required",
variadic=False,
parent=None,
inheritance=[appdef_xml_root],
inheritance=appdef_inheritance_chain,
)
# Set root attributes
nx_root, _ = get_nxdl_root_and_path("NXroot")
Expand All @@ -673,4 +703,8 @@ def add_children_to(parent: NexusNode, xml_elem: ET._Element) -> None:
entry = appdef_xml_root.find("nx:group[@type='NXentry']", namespaces=namespaces)
add_children_to(tree, entry)

# Add all fields and attributes from the parent appdefs
if len(appdef_inheritance_chain) > 1:
populate_tree_from_parents(tree)

return tree
11 changes: 9 additions & 2 deletions src/pynxtools/dataconverter/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Collector,
ValidationProblem,
collector,
convert_nexus_to_caps,
is_valid_data_field,
)
from pynxtools.dataconverter.nexus_tree import (
Expand Down Expand Up @@ -186,8 +187,14 @@ def validate_dict_against(
"""

def get_variations_of(node: NexusNode, keys: Mapping[str, Any]) -> List[str]:
if not node.variadic and node.name in keys:
return [node.name]
if not node.variadic:
if node.name in keys:
return [node.name]
elif (
hasattr(node, "nx_class")
and f"{convert_nexus_to_caps(node.nx_class)}[{node.name}]" in keys
):
return [f"{convert_nexus_to_caps(node.nx_class)}[{node.name}]"]

variations = []
for key in keys:
Expand Down
50 changes: 49 additions & 1 deletion tests/dataconverter/test_nexus_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import get_args
from typing import Any, List, Tuple, get_args

from anytree import Resolver
from pynxtools.dataconverter.nexus_tree import (
NexusNode,
NexusType,
NexusUnitCategory,
generate_tree_from,
Expand Down Expand Up @@ -31,3 +33,49 @@ def test_if_all_types_are_present():
pydantic_literal_values = get_args(NexusType)

assert set(reference_types) == set(pydantic_literal_values)


def test_correct_extension_of_tree():
nxtest = generate_tree_from("NXtest")
nxtest_extended = generate_tree_from("NXtest_extended")

def get_node_fields(tree: NexusNode) -> List[Tuple[str, Any]]:
return list(
filter(
lambda x: not x[0].startswith("_") and x[0] not in "inheritance",
tree.__dict__.items(),
)
)

def left_tree_in_right_tree(left_tree, right_tree):
for left_child in left_tree.children:
if left_child.name not in map(lambda x: x.name, right_tree.children):
return False
right_child = list(
filter(lambda x: x.name == left_child.name, right_tree.children)
)[0]
if left_child.name == "definition":
# Definition should be overwritten
if not left_child.items == ["NXTEST", "NXtest"]:
return False
if not right_child.items == ["NXtest_extended"]:
return False
continue
for field in get_node_fields(left_child):
if field not in get_node_fields(right_child):
return False
if not left_tree_in_right_tree(left_child, right_child):
return False
return True

assert left_tree_in_right_tree(nxtest, nxtest_extended)

resolver = Resolver("name", relax=True)
extended_field = resolver.get(nxtest_extended, "ENTRY/extended_field")
assert extended_field is not None
assert extended_field.unit == "NX_ENERGY"
assert extended_field.dtype == "NX_FLOAT"
assert extended_field.optionality == "required"

nxtest_field = resolver.get(nxtest, "ENTRY/extended_field")
assert nxtest_field is None

0 comments on commit 70d74b3

Please sign in to comment.