-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4f5e35e
commit e85bad1
Showing
9 changed files
with
16 additions
and
201 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -568,4 +568,5 @@ def add_edges( | |
|
||
return self | ||
|
||
# Functions to be added | ||
|
||
# Functions to be added |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,8 @@ | ||
# Author: Nicolas Legrand <[email protected]> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
continuous_node_posterior_update_ehgf, | ||
) | ||
from pyhgf.updates.prediction.binary import binary_state_node_prediction | ||
from pyhgf.updates.prediction.continuous import continuous_node_prediction | ||
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction | ||
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error | ||
from pyhgf.updates.prediction_error.categorical import ( | ||
categorical_state_prediction_error, | ||
) | ||
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error | ||
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error | ||
from pyhgf.updates.prediction_error.exponential import ( | ||
prediction_error_update_exponential_family, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pyhgf.model import Network | ||
from typing import Callable, Dict, List, Optional, Tuple, Union | ||
|
||
from pyhgf.typing import AdjacencyLists, Edges | ||
|
||
|
||
def add_edges( | ||
|
@@ -187,5 +157,3 @@ def add_edges( | |
edges = tuple(edges_as_list) | ||
|
||
return attributes, edges | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,13 @@ | ||
# Author: Nicolas Legrand <[email protected]> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
from typing import Dict, Tuple | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.typing import Attributes, Edges, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
continuous_node_posterior_update_ehgf, | ||
) | ||
from pyhgf.updates.prediction.binary import binary_state_node_prediction | ||
from pyhgf.updates.prediction.continuous import continuous_node_prediction | ||
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction | ||
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error | ||
from pyhgf.updates.prediction_error.categorical import ( | ||
categorical_state_prediction_error, | ||
) | ||
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error | ||
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error | ||
from pyhgf.updates.prediction_error.exponential import ( | ||
prediction_error_update_exponential_family, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pyhgf.model import Network | ||
|
||
|
||
@partial(jit, static_argnames=("update_sequence", "edges", "input_idxs")) | ||
|
@@ -126,4 +101,3 @@ def beliefs_propagation( | |
attributes, | ||
attributes, | ||
) # ("carryover", "accumulated") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,8 @@ | ||
# Author: Nicolas Legrand <[email protected]> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
from typing import TYPE_CHECKING, Dict, List | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
continuous_node_posterior_update_ehgf, | ||
) | ||
from pyhgf.updates.prediction.binary import binary_state_node_prediction | ||
from pyhgf.updates.prediction.continuous import continuous_node_prediction | ||
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction | ||
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error | ||
from pyhgf.updates.prediction_error.categorical import ( | ||
categorical_state_prediction_error, | ||
) | ||
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error | ||
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error | ||
from pyhgf.updates.prediction_error.exponential import ( | ||
prediction_error_update_exponential_family, | ||
) | ||
from pyhgf.typing import AdjacencyLists | ||
|
||
if TYPE_CHECKING: | ||
from pyhgf.model import Network | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,8 @@ | ||
# Author: Nicolas Legrand <[email protected]> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
from typing import Tuple | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
continuous_node_posterior_update_ehgf, | ||
) | ||
from pyhgf.updates.prediction.binary import binary_state_node_prediction | ||
from pyhgf.updates.prediction.continuous import continuous_node_prediction | ||
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction | ||
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error | ||
from pyhgf.updates.prediction_error.categorical import ( | ||
categorical_state_prediction_error, | ||
) | ||
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error | ||
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error | ||
from pyhgf.updates.prediction_error.exponential import ( | ||
prediction_error_update_exponential_family, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pyhgf.model import Network | ||
from pyhgf.typing import Edges | ||
|
||
|
||
def get_input_idxs(edges: Edges) -> Tuple[int, ...]: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,10 @@ | ||
# Author: Nicolas Legrand <[email protected]> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
from typing import TYPE_CHECKING, List, Tuple | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.typing import Sequence | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,8 @@ | ||
# Author: Nicolas Legrand <[email protected]> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
from typing import List, Tuple | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
continuous_node_posterior_update_ehgf, | ||
) | ||
from pyhgf.updates.prediction.binary import binary_state_node_prediction | ||
from pyhgf.updates.prediction.continuous import continuous_node_prediction | ||
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction | ||
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error | ||
from pyhgf.updates.prediction_error.categorical import ( | ||
categorical_state_prediction_error, | ||
) | ||
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error | ||
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error | ||
from pyhgf.updates.prediction_error.exponential import ( | ||
prediction_error_update_exponential_family, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pyhgf.model import Network | ||
|
||
|
||
def list_branches(node_idxs: List, edges: Tuple, branch_list: List = []) -> List: | ||
|
@@ -90,4 +60,3 @@ def list_branches(node_idxs: List, edges: Tuple, branch_list: List = []) -> List | |
) | ||
|
||
return branch_list | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,11 @@ | ||
# Author: Nicolas Legrand <[email protected]> | ||
|
||
from functools import partial | ||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union | ||
from typing import TYPE_CHECKING | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pandas as pd | ||
from jax import jit | ||
from jax.tree_util import Partial | ||
from jax.typing import ArrayLike | ||
|
||
from pyhgf.math import binary_surprise, gaussian_surprise | ||
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence | ||
from pyhgf.updates.observation import set_observation | ||
from pyhgf.updates.posterior.categorical import categorical_state_update | ||
from pyhgf.updates.posterior.continuous import ( | ||
continuous_node_posterior_update, | ||
continuous_node_posterior_update_ehgf, | ||
) | ||
from pyhgf.updates.prediction.binary import binary_state_node_prediction | ||
from pyhgf.updates.prediction.continuous import continuous_node_prediction | ||
from pyhgf.updates.prediction.dirichlet import dirichlet_node_prediction | ||
from pyhgf.updates.prediction_error.binary import binary_state_node_prediction_error | ||
from pyhgf.updates.prediction_error.categorical import ( | ||
categorical_state_prediction_error, | ||
) | ||
from pyhgf.updates.prediction_error.continuous import continuous_node_prediction_error | ||
from pyhgf.updates.prediction_error.dirichlet import dirichlet_node_prediction_error | ||
from pyhgf.updates.prediction_error.exponential import ( | ||
prediction_error_update_exponential_family, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from pyhgf.model import Network | ||
|
@@ -128,5 +104,3 @@ def to_pandas(network: "Network") -> pd.DataFrame: | |
].sum(axis=1, min_count=1) | ||
|
||
return trajectories_df | ||
|
||
|