Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add remove_node function #267

Merged
merged 6 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ Utilities for manipulating neural networks.
to_pandas
add_edges
get_input_idxs
remove_node

Math
****
Expand Down
2 changes: 2 additions & 0 deletions pyhgf/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .get_input_idxs import get_input_idxs
from .get_update_sequence import get_update_sequence
from .list_branches import list_branches
from .remove_node import remove_node
from .to_pandas import to_pandas

__all__ = [
Expand All @@ -14,4 +15,5 @@
"get_update_sequence",
"list_branches",
"to_pandas",
"remove_node",
]
267 changes: 267 additions & 0 deletions pyhgf/utils/remove_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
# Author: Louie Mølgaard Hessellund <[email protected]>

from typing import Dict, List, Tuple, Union

from pyhgf.typing import AdjacencyLists, Edges


def _remove_edges(
attributes: Dict,
edges: Edges,
kind: str = "value",
parent_idxs=Union[int, List[int]],
children_idxs=Union[int, List[int]],
) -> Tuple[Dict, Edges]:
"""Remove a value or volatility coupling link between a set of nodes.

Parameters
----------
attributes :
Attributes of the neural network.
edges :
Edges of the neural network.
kind :
The kind of coupling to remove, can be `"value"` or `"volatility"`.
parent_idxs :
The index(es) of the parent node(s) to disconnect.
children_idxs :
The index(es) of the children node(s) to disconnect.

Returns
-------
Tuple[Dict, Edges]
Updated attributes and edges with removed connections.

"""
if kind not in ["value", "volatility"]:
raise ValueError(
f"The kind of coupling should be value or volatility, got {kind}"
)

if isinstance(children_idxs, int):
children_idxs = [children_idxs]
if isinstance(parent_idxs, int):
parent_idxs = [parent_idxs]

edges_as_list = list(edges)

# Update parent nodes
for parent_idx in parent_idxs:
if parent_idx >= len(edges_as_list):
continue

node = edges_as_list[parent_idx]
children = node.value_children if kind == "value" else node.volatility_children
coupling_key = f"{kind}_coupling_children"

if children is not None and children:
# Get indices of children to keep
keep_indices = [
i for i, child in enumerate(children) if child not in children_idxs
]
new_children = tuple(children[i] for i in keep_indices)

# Update coupling strengths if they exist
if (
coupling_key in attributes[parent_idx]
and attributes[parent_idx][coupling_key]
):
new_strengths = tuple(
attributes[parent_idx][coupling_key][i] for i in keep_indices
)
attributes[parent_idx][coupling_key] = (
new_strengths if new_strengths else None
)

# Update node edges
if kind == "value":
edges_as_list[parent_idx] = AdjacencyLists(
node.node_type,
node.value_parents,
node.volatility_parents,
new_children if new_children else None,
node.volatility_children,
node.coupling_fn,
)
else:
edges_as_list[parent_idx] = AdjacencyLists(
node.node_type,
node.value_parents,
node.volatility_parents,
node.value_children,
new_children if new_children else None,
node.coupling_fn,
)

# Update children nodes
for child_idx in children_idxs:
if child_idx >= len(edges_as_list):
continue

node = edges_as_list[child_idx]
parents = node.value_parents if kind == "value" else node.volatility_parents
coupling_key = f"{kind}_coupling_parents"

if parents is not None and parents:
# Get indices of parents to keep
keep_indices = [
i for i, parent in enumerate(parents) if parent not in parent_idxs
]
new_parents = tuple(parents[i] for i in keep_indices)

# Update coupling strengths if they exist
if (
coupling_key in attributes[child_idx]
and attributes[child_idx][coupling_key]
):
new_strengths = tuple(
attributes[child_idx][coupling_key][i] for i in keep_indices
)
attributes[child_idx][coupling_key] = (
new_strengths if new_strengths else None
)

# Update node edges
if kind == "value":
edges_as_list[child_idx] = AdjacencyLists(
node.node_type,
new_parents if new_parents else None,
node.volatility_parents,
node.value_children,
node.volatility_children,
node.coupling_fn,
)
else:
edges_as_list[child_idx] = AdjacencyLists(
node.node_type,
node.value_parents,
new_parents if new_parents else None,
node.value_children,
node.volatility_children,
node.coupling_fn,
)

return attributes, tuple(edges_as_list)


def remove_node(attributes: Dict, edges: Edges, index: int) -> Tuple[Dict, Edges]:
"""Remove a given node from the network.

This function removes a node from the network by deleting its parameters in the
attributes and edges variables, and adjusts the indices of the remaining nodes.

Parameters
----------
attributes :
The attributes of the network.
edges :
The edges of the network.
index :
The index of the node to remove.

Returns
-------
Tuple[Dict, Edges]
Updated attributes and edges with the node removed and indices adjusted.

"""
# ensure that the node exists in the network
if index not in attributes or index >= len(edges):
raise ValueError(f"Node with index {index} does not exist in the network")

edges_as_list = list(edges)
node = edges_as_list[index]

# First remove all connections to/from this node using the _remove_edges function
if node.value_parents:
attributes, edges = _remove_edges(
attributes,
edges,
"value",
parent_idxs=node.value_parents,
children_idxs=index,
)
edges_as_list = list(edges)

if node.volatility_parents:
attributes, edges = _remove_edges(
attributes,
edges,
"volatility",
parent_idxs=node.volatility_parents,
children_idxs=index,
)
edges_as_list = list(edges)

if node.value_children:
attributes, edges = _remove_edges(
attributes,
edges,
"value",
parent_idxs=index,
children_idxs=node.value_children,
)
edges_as_list = list(edges)

if node.volatility_children:
attributes, edges = _remove_edges(
attributes,
edges,
"volatility",
parent_idxs=index,
children_idxs=node.volatility_children,
)
edges_as_list = list(edges)

# Now remove the node
edges_as_list.pop(index)
attributes.pop(index)

# Create new edges list with adjusted indices
new_edges = []
for node in edges_as_list:
new_value_parents = None
new_volatility_parents = None
new_value_children = None
new_volatility_children = None

if node.value_parents:
new_value_parents = tuple(
p if p < index else p - 1 for p in node.value_parents
)

if node.volatility_parents:
new_volatility_parents = tuple(
p if p < index else p - 1 for p in node.volatility_parents
)

if node.value_children:
new_value_children = tuple(
c if c < index else c - 1 for c in node.value_children
)

if node.volatility_children:
new_volatility_children = tuple(
c if c < index else c - 1 for c in node.volatility_children
)

new_edges.append(
AdjacencyLists(
node.node_type,
new_value_parents,
new_volatility_parents,
new_value_children,
new_volatility_children,
node.coupling_fn,
)
)

# Adjust attributes indices
new_attributes = {-1: attributes[-1]} # Preserve the time_step
for old_idx, attr in attributes.items():
if old_idx == -1 or old_idx == index:
continue
new_idx = old_idx if old_idx < index else old_idx - 1
new_attributes[new_idx] = attr

return new_attributes, tuple(new_edges)
19 changes: 18 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pyhgf import load_data
from pyhgf.model import Network
from pyhgf.typing import AdjacencyLists
from pyhgf.utils import list_branches
from pyhgf.utils import list_branches, remove_node


def test_imports():
Expand Down Expand Up @@ -92,3 +92,20 @@ def test_set_update_sequence():
predictions, updates = network4.update_sequence
assert len(predictions) == 1
assert len(updates) == 3


def test_remove_node():
"""Test the remove_node function."""
# a standard binary HGF
network = (
Network()
.add_nodes(n_nodes=4)
.add_nodes(value_children=2)
.add_nodes(value_children=3)
)

attributes, edges, _ = network.get_network()
new_attributes, new_edges = remove_node(attributes, edges, 1)

assert len(new_attributes) == 6
assert len(new_edges) == 5
Loading