Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for conditional expressions on workflows. #666

Merged
merged 2 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions src/sisl/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
lazy_init = self.context["lazy"]

if not lazy_init:
self.get()

Check warning

Code scanning / CodeQL

`__init__` method calls overridden method Warning

Call to self.
get
in __init__ method, which is overridden by
method Workflow.get
.

def __call__(self, *args, **kwargs):
self.update_inputs(*args, **kwargs)
Expand Down Expand Up @@ -362,9 +362,34 @@
def evaluate_input_node(node: Node):
return node.get()

def get(self):
def _get_evaluated_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Evaluates all inputs.

This function is ONLY called by the get method.

The default implementation just goes over the inputs and, if they are nodes,
makes them compute their output. But some nodes might need more complex things,
like only evaluating some inputs depending on the value of other inputs.

Parameters
----------
inputs : Dict[str, Any]
The input dictionary, possibly containing nodes to evaluate.
"""
# Map all inputs to their values. That is, if they are nodes, call the get
# method on them so that we get the updated output. This recursively evaluates nodes.
return self.map_inputs(
inputs=inputs,
func=self.evaluate_input_node,
only_nodes=True,
)

def get(self):
"""Returns the output of the node, possibly running the computation.

The computation of the node is only performed if the output is outdated,
otherwise this function just returns the stored output.
"""
self._logger.setLevel(getattr(logging, self.context["log_level"].upper()))

logs = logging.StreamHandler(StringIO())
Expand All @@ -375,11 +400,7 @@
self._logger.debug("Getting output from node...")
self._logger.debug(f"Raw inputs: {self._inputs}")

evaluated_inputs = self.map_inputs(
inputs=self._inputs,
func=self.evaluate_input_node,
only_nodes=True,
)
evaluated_inputs = self._get_evaluated_inputs(self._inputs)

self._logger.debug(f"Evaluated inputs: {evaluated_inputs}")

Expand Down Expand Up @@ -656,6 +677,10 @@
if not self.context["lazy"]:
self.get()

def get_diagram_label(self):
"""Returns the label to be used in diagrams when displaying this node."""
return None

Check warning on line 682 in src/sisl/nodes/node.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/node.py#L682

Added line #L682 was not covered by tests


class DummyInputValue(Node):
"""A dummy node that can be used as a placeholder for input values."""
Expand Down
113 changes: 113 additions & 0 deletions src/sisl/nodes/syntax_nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import operator
from typing import Any, Dict

from .node import Node


Expand All @@ -21,3 +24,113 @@
@staticmethod
def function(**items):
return items


class ConditionalExpressionSyntaxNode(SyntaxNode):
_outdate_due_to_inputs: bool = False

def _get_evaluated_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Evaluate the inputs of this node.

This function overwrites the default implementation in Node, because
we want to evaluate only the path that we are going to take.

Parameters
----------
inputs : dict
The inputs to this node.
"""

evaluated = {}

# Get the state of the test input, which determines the path that we are going to take.
evaluated["test"] = (
self.evaluate_input_node(inputs["test"])
if isinstance(inputs["test"], Node)
else inputs["test"]
)

# Evaluate only the path that we are going to take.
if evaluated["test"]:
evaluated["true"] = (
self.evaluate_input_node(inputs["true"])
if isinstance(inputs["true"], Node)
else inputs["true"]
)
evaluated["false"] = self._prev_evaluated_inputs.get("false")
else:
evaluated["false"] = (
self.evaluate_input_node(inputs["false"])
if isinstance(inputs["false"], Node)
else inputs["false"]
)
evaluated["true"] = self._prev_evaluated_inputs.get("true")

return evaluated

def update_inputs(self, **inputs):
# This is just a wrapper over the normal update_inputs, which makes
# sure that the node is only marked as outdated if the input that
# is being used has changed. Note that here we just create a flag,
# which is then used in _receive_outdated. (_receive_outdated is
# called by super().update_inputs())
current_test = self._prev_evaluated_inputs["test"]

self._outdate_due_to_inputs = len(inputs) > 0
if "test" not in inputs:
if current_test and ("true" not in inputs):
self._outdate_due_to_inputs = False

Check warning on line 82 in src/sisl/nodes/syntax_nodes.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/syntax_nodes.py#L82

Added line #L82 was not covered by tests
elif not current_test and ("false" not in inputs):
self._outdate_due_to_inputs = False

try:
super().update_inputs(**inputs)
except:
self._outdate_due_to_inputs = False
raise

Check warning on line 90 in src/sisl/nodes/syntax_nodes.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/syntax_nodes.py#L88-L90

Added lines #L88 - L90 were not covered by tests

def _receive_outdated(self):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
# Relevant inputs have been updated, mark this node as outdated.
if self._outdate_due_to_inputs:
return super()._receive_outdated()

# We avoid marking this node as outdated if the outdated input
# is not the one being returned.
for k in self._input_nodes:
if self._input_nodes[k]._outdated:
if k == "test":
return super()._receive_outdated()
elif k == "true":
if self._prev_evaluated_inputs["test"]:
return super()._receive_outdated()
elif k == "false":
if not self._prev_evaluated_inputs["test"]:
return super()._receive_outdated()

Check warning on line 108 in src/sisl/nodes/syntax_nodes.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/syntax_nodes.py#L100-L108

Added lines #L100 - L108 were not covered by tests

@staticmethod
def function(test, true, false):
return true if test else false

def get_diagram_label(self):
"""Returns the label to be used in diagrams when displaying this node."""
return "if/else"

Check warning on line 116 in src/sisl/nodes/syntax_nodes.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/syntax_nodes.py#L116

Added line #L116 was not covered by tests


class CompareSyntaxNode(SyntaxNode):
_op_to_symbol = {
"eq": "==",
"ne": "!=",
"gt": ">",
"lt": "<",
"ge": ">=",
"le": "<=",
None: "compare",
}

@staticmethod
def function(left, op: str, right):
return getattr(operator, op)(left, right)

def get_diagram_label(self):
"""Returns the label to be used in diagrams when displaying this node."""
return self._op_to_symbol.get(self._prev_evaluated_inputs.get("op"))

Check warning on line 136 in src/sisl/nodes/syntax_nodes.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/syntax_nodes.py#L136

Added line #L136 was not covered by tests
51 changes: 50 additions & 1 deletion src/sisl/nodes/tests/test_syntax_nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from sisl.nodes.syntax_nodes import DictSyntaxNode, ListSyntaxNode, TupleSyntaxNode
from sisl.nodes.node import ConstantNode
from sisl.nodes.syntax_nodes import (
CompareSyntaxNode,
ConditionalExpressionSyntaxNode,
DictSyntaxNode,
ListSyntaxNode,
TupleSyntaxNode,
)
from sisl.nodes.workflow import Workflow


Expand All @@ -14,6 +21,38 @@
assert DictSyntaxNode(a="b", c="d", e="f").get() == {"a": "b", "c": "d", "e": "f"}


def test_cond_expr_node():
node = ConditionalExpressionSyntaxNode(test=True, true=1, false=2)

assert node.get() == 1
node.update_inputs(test=False)

assert node._outdated
assert node.get() == 2

node.update_inputs(true=3)
assert not node._outdated

# Check that only one path is evaluated.
input1 = ConstantNode(1)
input2 = ConstantNode(2)

node = ConditionalExpressionSyntaxNode(test=True, true=input1, false=input2)

assert node.get() == 1
assert input1._nupdates == 1
assert input2._nupdates == 0


def test_compare_syntax_node():
assert CompareSyntaxNode(1, "eq", 2).get() == False
assert CompareSyntaxNode(1, "ne", 2).get() == True
assert CompareSyntaxNode(1, "gt", 2).get() == False
assert CompareSyntaxNode(1, "lt", 2).get() == True
assert CompareSyntaxNode(1, "ge", 2).get() == False
assert CompareSyntaxNode(1, "le", 2).get() == True


def test_workflow_with_syntax():
def f(a):
return [a]
Expand All @@ -29,3 +68,13 @@
return {"a": a}

assert Workflow.from_func(f)(2).get() == {"a": 2}

def f(a, b, c):
return b if a else c

Check warning on line 73 in src/sisl/nodes/tests/test_syntax_nodes.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/tests/test_syntax_nodes.py#L73

Added line #L73 was not covered by tests

assert Workflow.from_func(f)(False, 1, 2).get() == 2

def f(a, b):
return a != b

Check warning on line 78 in src/sisl/nodes/tests/test_syntax_nodes.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/tests/test_syntax_nodes.py#L78

Added line #L78 was not covered by tests

assert Workflow.from_func(f)(1, 2).get() == True
70 changes: 68 additions & 2 deletions src/sisl/nodes/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@

from .context import temporal_context
from .node import DummyInputValue, Node
from .syntax_nodes import DictSyntaxNode, ListSyntaxNode, TupleSyntaxNode
from .syntax_nodes import (
CompareSyntaxNode,
ConditionalExpressionSyntaxNode,
DictSyntaxNode,
ListSyntaxNode,
TupleSyntaxNode,
)
from .utils import traverse_tree_backward, traverse_tree_forward

register_environ_variable(
Expand Down Expand Up @@ -337,8 +343,9 @@
for node in nodes:
graph_node = graph.nodes[node]

node_obj = self._workflow.dryrun_nodes.get(node)

Check warning on line 346 in src/sisl/nodes/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/workflow.py#L346

Added line #L346 was not covered by tests

if node_help:
node_obj = self._workflow.dryrun_nodes.get(node)
title = (
_get_node_inputs_str(node_obj) if node_obj is not None else ""
)
Expand All @@ -353,6 +360,7 @@
"level": level,
"title": title,
"font": font,
"label": node_obj.get_diagram_label(),
**node_props,
}
)
Expand Down Expand Up @@ -996,6 +1004,15 @@
class NodeConverter(ast.NodeTransformer):
"""AST transformer that converts a function into a workflow."""

ast_to_operator = {
ast.Eq: "eq",
ast.NotEq: "ne",
ast.Lt: "lt",
ast.LtE: "le",
ast.Gt: "gt",
ast.GtE: "ge",
}

def __init__(
self,
*args,
Expand Down Expand Up @@ -1101,6 +1118,53 @@

return new_node

def visit_IfExp(self, node: ast.IfExp) -> Any:
"""Converts the if expression syntax into a call to the ConditionalExpressionSyntaxNode."""
new_node = ast.Call(
func=ast.Name(id="ConditionalExpressionSyntaxNode", ctx=ast.Load()),
args=[
self.visit(node.test),
self.visit(node.body),
self.visit(node.orelse),
],
keywords=[],
)

ast.fix_missing_locations(new_node)

return new_node

def visit_Compare(self, node: ast.Compare) -> Any:
"""Converts the comparison syntax into CompareSyntaxNode call."""
if len(node.ops) > 1:
return self.generic_visit(node)

Check warning on line 1140 in src/sisl/nodes/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/workflow.py#L1140

Added line #L1140 was not covered by tests

op = node.ops[0]
if op.__class__ not in self.ast_to_operator:
return self.generic_visit(node)

Check warning on line 1144 in src/sisl/nodes/workflow.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/nodes/workflow.py#L1144

Added line #L1144 was not covered by tests

new_node = ast.Call(
func=ast.Name(id="CompareSyntaxNode", ctx=ast.Load()),
args=[
self.visit(node.left),
ast.Constant(value=self.ast_to_operator[op.__class__], ctx=ast.Load()),
self.visit(node.comparators[0]),
],
keywords=[],
)

ast.fix_missing_locations(new_node)

# new_node = ast.Call(
# func=ast.Name(id=self.ast_to_operator[op.__class__], ctx=ast.Load()),
# args=[self.visit(node.left), self.visit(node.comparators[0])],
# keywords=[],
# )

# ast.fix_missing_locations(new_node)

return new_node


def nodify_func(
func: FunctionType,
Expand Down Expand Up @@ -1183,6 +1247,8 @@
"ListSyntaxNode": ListSyntaxNode,
"TupleSyntaxNode": TupleSyntaxNode,
"DictSyntaxNode": DictSyntaxNode,
"ConditionalExpressionSyntaxNode": ConditionalExpressionSyntaxNode,
"CompareSyntaxNode": CompareSyntaxNode,
**func_namespace,
}
if assign_fn_key is not None:
Expand Down
13 changes: 5 additions & 8 deletions src/sisl/viz/plots/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from ..plotutils import random_color
from ..processors.bands import calculate_gap, draw_gaps, filter_bands, style_bands
from ..processors.data import accept_data
from ..processors.logic import matches
from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data
from ..processors.xarray import scale_variable
from .orbital_groups_plot import OrbitalGroupsPlot
Expand Down Expand Up @@ -95,8 +94,8 @@
)

# Determine what goes on each axis
x = matches(E_axis, "x", ret_true="E", ret_false="k")
y = matches(E_axis, "y", ret_true="E", ret_false="k")
x = "E" if E_axis == "x" else "k"
y = "E" if E_axis == "y" else "k"

# Get the actions to plot lines
bands_plottings = draw_xarray_xy(
Expand Down Expand Up @@ -267,12 +266,10 @@
)

# Determine what goes on each axis
x = matches(E_axis, "x", ret_true="E", ret_false="k")
y = matches(E_axis, "y", ret_true="E", ret_false="k")
x = "E" if E_axis == "x" else "k"
y = "E" if E_axis == "y" else "k"

Check warning on line 270 in src/sisl/viz/plots/bands.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/plots/bands.py#L269-L270

Added lines #L269 - L270 were not covered by tests

sanitized_fatbands_mode = matches(
groups, [], ret_true="none", ret_false=fatbands_mode
)
sanitized_fatbands_mode = "none" if groups == [] else fatbands_mode

Check warning on line 272 in src/sisl/viz/plots/bands.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/plots/bands.py#L272

Added line #L272 was not covered by tests

# Get the actions to plot lines
fatbands_plottings = draw_xarray_xy(
Expand Down
Loading