Skip to content

Commit

Permalink
formating
Browse files Browse the repository at this point in the history
  • Loading branch information
SylvainEstebe committed Nov 29, 2024
1 parent 4f5e35e commit e85bad1
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 201 deletions.
3 changes: 2 additions & 1 deletion pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,4 +568,5 @@ def add_edges(

return self

# Functions to be added

# Functions to be added
8 changes: 1 addition & 7 deletions pyhgf/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from .add_edges import add_edges

from .beliefs_propagation import beliefs_propagation

from .fill_categorical_state_node import fill_categorical_state_node

from .get_input_idxs import get_input_idxs

from .get_update_sequence import get_update_sequence

from .list_branches import list_branches

from .to_pandas import to_pandas

__all__ = [
Expand All @@ -20,4 +14,4 @@
"get_update_sequence",
"list_branches",
"to_pandas",
]
]
38 changes: 3 additions & 35 deletions pyhgf/utils/add_edges.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,8 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error
from pyhgf.updates.prediction_error.categorical import (
categorical_state_prediction_error,
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)

if TYPE_CHECKING:
from pyhgf.model import Network
from typing import Callable, Dict, List, Optional, Tuple, Union

from pyhgf.typing import AdjacencyLists, Edges


def add_edges(
Expand Down Expand Up @@ -187,5 +157,3 @@ def add_edges(
edges = tuple(edges_as_list)

return attributes, edges


30 changes: 2 additions & 28 deletions pyhgf/utils/beliefs_propagation.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,13 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import Dict, Tuple

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.typing import Attributes, Edges, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error
from pyhgf.updates.prediction_error.categorical import (
categorical_state_prediction_error,
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)

if TYPE_CHECKING:
from pyhgf.model import Network


@partial(jit, static_argnames=("update_sequence", "edges", "input_idxs"))
Expand Down Expand Up @@ -126,4 +101,3 @@ def beliefs_propagation(
attributes,
attributes,
) # ("carryover", "accumulated")

31 changes: 2 additions & 29 deletions pyhgf/utils/fill_categorical_state_node.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,8 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error
from pyhgf.updates.prediction_error.categorical import (
categorical_state_prediction_error,
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)
from pyhgf.typing import AdjacencyLists

if TYPE_CHECKING:
from pyhgf.model import Network
Expand Down
34 changes: 2 additions & 32 deletions pyhgf/utils/get_input_idxs.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,8 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import Tuple

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error
from pyhgf.updates.prediction_error.categorical import (
categorical_state_prediction_error,
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)

if TYPE_CHECKING:
from pyhgf.model import Network
from pyhgf.typing import Edges


def get_input_idxs(edges: Edges) -> Tuple[int, ...]:
Expand Down
12 changes: 2 additions & 10 deletions pyhgf/utils/get_update_sequence.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Tuple

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.typing import Sequence
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
Expand Down
33 changes: 1 addition & 32 deletions pyhgf/utils/list_branches.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,8 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import List, Tuple

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error
from pyhgf.updates.prediction_error.categorical import (
categorical_state_prediction_error,
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)

if TYPE_CHECKING:
from pyhgf.model import Network


def list_branches(node_idxs: List, edges: Tuple, branch_list: List = []) -> List:
Expand Down Expand Up @@ -90,4 +60,3 @@ def list_branches(node_idxs: List, edges: Tuple, branch_list: List = []) -> List
)

return branch_list

28 changes: 1 addition & 27 deletions pyhgf/utils/to_pandas.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,11 @@
# Author: Nicolas Legrand <[email protected]>

from functools import partial
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import jit
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
)
from pyhgf.updates.prediction.binary import binary_state_node_prediction
from pyhgf.updates.prediction.continuous import continuous_node_prediction
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error
from pyhgf.updates.prediction_error.categorical import (
categorical_state_prediction_error,
)
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error
from pyhgf.updates.prediction_error.exponential import (
prediction_error_update_exponential_family,
)

if TYPE_CHECKING:
from pyhgf.model import Network
Expand Down Expand Up @@ -128,5 +104,3 @@ def to_pandas(network: "Network") -> pd.DataFrame:
].sum(axis=1, min_count=1)

return trajectories_df


0 comments on commit e85bad1

Please sign in to comment.