diff --git a/src/pyhgf/plots.py b/src/pyhgf/plots.py index 4d900ae5a..02e7646e0 100644 --- a/src/pyhgf/plots.py +++ b/src/pyhgf/plots.py @@ -594,28 +594,29 @@ def plot_nodes( # plotting surprise # ----------------- if show_surprise: - surprise_ax = axs[i].twinx() - surprise_ax.plot( - trajectories_df.time, - trajectories_df[f"x_{node_idx}_surprise"], - color="#2a2a2a", - linewidth=0.5, - zorder=-1, - label="Surprise", - ) - surprise_ax.fill_between( - x=trajectories_df.time, - y1=trajectories_df[f"x_{node_idx}_surprise"], - y2=trajectories_df[f"x_{node_idx}_surprise"].min(), - color="#7f7f7f", - alpha=0.1, - zorder=-1, - ) - sp = trajectories_df[f"x_{node_idx}_surprise"].sum() - surprise_ax.set_title( - f"Node {node_idx} - Surprise: {sp:.2f}", - loc="left", - ) - surprise_ax.set_ylabel("Surprise") - surprise_ax.legend() + if not trajectories_df[f"x_{node_idx}_surprise"].isnull().all(): + surprise_ax = axs[i].twinx() + surprise_ax.plot( + trajectories_df.time, + trajectories_df[f"x_{node_idx}_surprise"], + color="#2a2a2a", + linewidth=0.5, + zorder=-1, + label="Surprise", + ) + surprise_ax.fill_between( + x=trajectories_df.time, + y1=trajectories_df[f"x_{node_idx}_surprise"], + y2=trajectories_df[f"x_{node_idx}_surprise"].min(), + color="#7f7f7f", + alpha=0.1, + zorder=-1, + ) + sp = trajectories_df[f"x_{node_idx}_surprise"].sum() + surprise_ax.set_title( + f"Node {node_idx} - Surprise: {sp:.2f}", + loc="left", + ) + surprise_ax.set_ylabel("Surprise") + surprise_ax.legend() return axs