diff --git a/pyhgf/model/network.py b/pyhgf/model/network.py index 454223e21..c6700cc8c 100644 --- a/pyhgf/model/network.py +++ b/pyhgf/model/network.py @@ -568,4 +568,5 @@ def add_edges( return self -# Functions to be added \ No newline at end of file + +# Functions to be added diff --git a/pyhgf/utils/__init__.py b/pyhgf/utils/__init__.py index a7ffa5616..1b77b75d4 100644 --- a/pyhgf/utils/__init__.py +++ b/pyhgf/utils/__init__.py @@ -1,15 +1,9 @@ from .add_edges import add_edges - from .beliefs_propagation import beliefs_propagation - from .fill_categorical_state_node import fill_categorical_state_node - from .get_input_idxs import get_input_idxs - from .get_update_sequence import get_update_sequence - from .list_branches import list_branches - from .to_pandas import to_pandas __all__ = [ @@ -20,4 +14,4 @@ "get_update_sequence", "list_branches", "to_pandas", -] \ No newline at end of file +] diff --git a/pyhgf/utils/add_edges.py b/pyhgf/utils/add_edges.py index 9082a4556..a1b4940ec 100644 --- a/pyhgf/utils/add_edges.py +++ b/pyhgf/utils/add_edges.py @@ -1,38 +1,8 @@ # Author: Nicolas Legrand -from functools import partial -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union - -import jax.numpy as jnp -import numpy as np -import pandas as pd -from jax import jit -from jax.tree_util import Partial -from jax.typing import ArrayLike - -from pyhgf.math import binary_surprise, gaussian_surprise -from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence -from pyhgf.updates.observation import set_observation -from pyhgf.updates.posterior.categorical import categorical_state_update -from pyhgf.updates.posterior.continuous import ( - continuous_node_posterior_update, - continuous_node_posterior_update_ehgf, -) -from pyhgf.updates.prediction.binary import binary_state_node_prediction -from pyhgf.updates.prediction.continuous import continuous_node_prediction -from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction -from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error -from pyhgf.updates.prediction_error.categorical import ( - categorical_state_prediction_error, -) -from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error -from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error -from pyhgf.updates.prediction_error.exponential import ( - prediction_error_update_exponential_family, -) - -if TYPE_CHECKING: - from pyhgf.model import Network +from typing import Callable, Dict, List, Optional, Tuple, Union + +from pyhgf.typing import AdjacencyLists, Edges def add_edges( @@ -187,5 +157,3 @@ def add_edges( edges = tuple(edges_as_list) return attributes, edges - - diff --git a/pyhgf/utils/beliefs_propagation.py b/pyhgf/utils/beliefs_propagation.py index b3d73db22..405420e70 100644 --- a/pyhgf/utils/beliefs_propagation.py +++ b/pyhgf/utils/beliefs_propagation.py @@ -1,38 +1,13 @@ # Author: Nicolas Legrand from functools import partial -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import Dict, Tuple -import jax.numpy as jnp -import numpy as np -import pandas as pd from jax import jit -from jax.tree_util import Partial from jax.typing import ArrayLike -from pyhgf.math import binary_surprise, gaussian_surprise -from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence +from pyhgf.typing import Attributes, Edges, UpdateSequence from pyhgf.updates.observation import set_observation -from pyhgf.updates.posterior.categorical import categorical_state_update -from pyhgf.updates.posterior.continuous import ( - continuous_node_posterior_update, - continuous_node_posterior_update_ehgf, -) -from pyhgf.updates.prediction.binary import binary_state_node_prediction -from pyhgf.updates.prediction.continuous import continuous_node_prediction -from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction -from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error -from pyhgf.updates.prediction_error.categorical import ( - categorical_state_prediction_error, -) -from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error -from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error -from pyhgf.updates.prediction_error.exponential import ( - prediction_error_update_exponential_family, -) - -if TYPE_CHECKING: - from pyhgf.model import Network @partial(jit, static_argnames=("update_sequence", "edges", "input_idxs")) @@ -126,4 +101,3 @@ def beliefs_propagation( attributes, attributes, ) # ("carryover", "accumulated") - diff --git a/pyhgf/utils/fill_categorical_state_node.py b/pyhgf/utils/fill_categorical_state_node.py index 35f41736d..9c6cf31d5 100644 --- a/pyhgf/utils/fill_categorical_state_node.py +++ b/pyhgf/utils/fill_categorical_state_node.py @@ -1,35 +1,8 @@ # Author: Nicolas Legrand -from functools import partial -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List -import jax.numpy as jnp -import numpy as np -import pandas as pd -from jax import jit -from jax.tree_util import Partial -from jax.typing import ArrayLike - -from pyhgf.math import binary_surprise, gaussian_surprise -from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence -from pyhgf.updates.observation import set_observation -from pyhgf.updates.posterior.categorical import categorical_state_update -from pyhgf.updates.posterior.continuous import ( - continuous_node_posterior_update, - continuous_node_posterior_update_ehgf, -) -from pyhgf.updates.prediction.binary import binary_state_node_prediction -from pyhgf.updates.prediction.continuous import continuous_node_prediction -from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction -from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error -from pyhgf.updates.prediction_error.categorical import ( - categorical_state_prediction_error, -) -from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error -from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error -from pyhgf.updates.prediction_error.exponential import ( - prediction_error_update_exponential_family, -) +from pyhgf.typing import AdjacencyLists if TYPE_CHECKING: from pyhgf.model import Network diff --git a/pyhgf/utils/get_input_idxs.py b/pyhgf/utils/get_input_idxs.py index c10b1b878..1ab054826 100644 --- a/pyhgf/utils/get_input_idxs.py +++ b/pyhgf/utils/get_input_idxs.py @@ -1,38 +1,8 @@ # Author: Nicolas Legrand -from functools import partial -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import Tuple -import jax.numpy as jnp -import numpy as np -import pandas as pd -from jax import jit -from jax.tree_util import Partial -from jax.typing import ArrayLike - -from pyhgf.math import binary_surprise, gaussian_surprise -from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence -from pyhgf.updates.observation import set_observation -from pyhgf.updates.posterior.categorical import categorical_state_update -from pyhgf.updates.posterior.continuous import ( - continuous_node_posterior_update, - continuous_node_posterior_update_ehgf, -) -from pyhgf.updates.prediction.binary import binary_state_node_prediction -from pyhgf.updates.prediction.continuous import continuous_node_prediction -from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction -from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error -from pyhgf.updates.prediction_error.categorical import ( - categorical_state_prediction_error, -) -from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error -from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error -from pyhgf.updates.prediction_error.exponential import ( - prediction_error_update_exponential_family, -) - -if TYPE_CHECKING: - from pyhgf.model import Network +from pyhgf.typing import Edges def get_input_idxs(edges: Edges) -> Tuple[int, ...]: diff --git a/pyhgf/utils/get_update_sequence.py b/pyhgf/utils/get_update_sequence.py index 82cb701c5..1ad304ccc 100644 --- a/pyhgf/utils/get_update_sequence.py +++ b/pyhgf/utils/get_update_sequence.py @@ -1,18 +1,10 @@ # Author: Nicolas Legrand -from functools import partial -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple -import jax.numpy as jnp -import numpy as np -import pandas as pd -from jax import jit from jax.tree_util import Partial -from jax.typing import ArrayLike -from pyhgf.math import binary_surprise, gaussian_surprise -from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence -from pyhgf.updates.observation import set_observation +from pyhgf.typing import Sequence from pyhgf.updates.posterior.categorical import categorical_state_update from pyhgf.updates.posterior.continuous import ( continuous_node_posterior_update, diff --git a/pyhgf/utils/list_branches.py b/pyhgf/utils/list_branches.py index be7d0e7c4..24e0953c5 100644 --- a/pyhgf/utils/list_branches.py +++ b/pyhgf/utils/list_branches.py @@ -1,38 +1,8 @@ # Author: Nicolas Legrand -from functools import partial -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import List, Tuple -import jax.numpy as jnp import numpy as np -import pandas as pd -from jax import jit -from jax.tree_util import Partial -from jax.typing import ArrayLike - -from pyhgf.math import binary_surprise, gaussian_surprise -from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence -from pyhgf.updates.observation import set_observation -from pyhgf.updates.posterior.categorical import categorical_state_update -from pyhgf.updates.posterior.continuous import ( - continuous_node_posterior_update, - continuous_node_posterior_update_ehgf, -) -from pyhgf.updates.prediction.binary import binary_state_node_prediction -from pyhgf.updates.prediction.continuous import continuous_node_prediction -from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction -from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error -from pyhgf.updates.prediction_error.categorical import ( - categorical_state_prediction_error, -) -from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error -from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error -from pyhgf.updates.prediction_error.exponential import ( - prediction_error_update_exponential_family, -) - -if TYPE_CHECKING: - from pyhgf.model import Network def list_branches(node_idxs: List, edges: Tuple, branch_list: List = []) -> List: @@ -90,4 +60,3 @@ def list_branches(node_idxs: List, edges: Tuple, branch_list: List = []) -> List ) return branch_list - diff --git a/pyhgf/utils/to_pandas.py b/pyhgf/utils/to_pandas.py index 5f4035f43..539edfdd2 100644 --- a/pyhgf/utils/to_pandas.py +++ b/pyhgf/utils/to_pandas.py @@ -1,35 +1,11 @@ # Author: Nicolas Legrand -from functools import partial -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING import jax.numpy as jnp -import numpy as np import pandas as pd -from jax import jit -from jax.tree_util import Partial -from jax.typing import ArrayLike from pyhgf.math import binary_surprise, gaussian_surprise -from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence -from pyhgf.updates.observation import set_observation -from pyhgf.updates.posterior.categorical import categorical_state_update -from pyhgf.updates.posterior.continuous import ( - continuous_node_posterior_update, - continuous_node_posterior_update_ehgf, -) -from pyhgf.updates.prediction.binary import binary_state_node_prediction -from pyhgf.updates.prediction.continuous import continuous_node_prediction -from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction -from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error -from pyhgf.updates.prediction_error.categorical import ( - categorical_state_prediction_error, -) -from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error -from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error -from pyhgf.updates.prediction_error.exponential import ( - prediction_error_update_exponential_family, -) if TYPE_CHECKING: from pyhgf.model import Network @@ -128,5 +104,3 @@ def to_pandas(network: "Network") -> pd.DataFrame: ].sum(axis=1, min_count=1) return trajectories_df - -