Skip to content

Commit

Permalink
do not plot surprise when there is none
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Nov 29, 2023
1 parent 351449c commit a81c27d
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions src/pyhgf/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,28 +594,29 @@ def plot_nodes(
# plotting surprise
# -----------------
if show_surprise:
surprise_ax = axs[i].twinx()
surprise_ax.plot(
trajectories_df.time,
trajectories_df[f"x_{node_idx}_surprise"],
color="#2a2a2a",
linewidth=0.5,
zorder=-1,
label="Surprise",
)
surprise_ax.fill_between(
x=trajectories_df.time,
y1=trajectories_df[f"x_{node_idx}_surprise"],
y2=trajectories_df[f"x_{node_idx}_surprise"].min(),
color="#7f7f7f",
alpha=0.1,
zorder=-1,
)
sp = trajectories_df[f"x_{node_idx}_surprise"].sum()
surprise_ax.set_title(
f"Node {node_idx} - Surprise: {sp:.2f}",
loc="left",
)
surprise_ax.set_ylabel("Surprise")
surprise_ax.legend()
if not trajectories_df[f"x_{node_idx}_surprise"].isnull().all():
surprise_ax = axs[i].twinx()
surprise_ax.plot(
trajectories_df.time,
trajectories_df[f"x_{node_idx}_surprise"],
color="#2a2a2a",
linewidth=0.5,
zorder=-1,
label="Surprise",
)
surprise_ax.fill_between(
x=trajectories_df.time,
y1=trajectories_df[f"x_{node_idx}_surprise"],
y2=trajectories_df[f"x_{node_idx}_surprise"].min(),
color="#7f7f7f",
alpha=0.1,
zorder=-1,
)
sp = trajectories_df[f"x_{node_idx}_surprise"].sum()
surprise_ax.set_title(
f"Node {node_idx} - Surprise: {sp:.2f}",
loc="left",
)
surprise_ax.set_ylabel("Surprise")
surprise_ax.legend()
return axs

0 comments on commit a81c27d

Please sign in to comment.