diff --git a/surprisal/interface.py b/surprisal/interface.py index 4418f98..69093e6 100644 --- a/surprisal/interface.py +++ b/surprisal/interface.py @@ -119,14 +119,16 @@ def lineplot(self, f=None, a=None, cumulative=False): f, a = plt.subplots() if cumulative: - arr = np.nan_to_num(self.surprisals.astype("float64")) + arr = self.surprisals.astype("float64") + arr = np.nan_to_num(arr, nan=0.0, posinf=float("nan")) arr = np.cumsum(arr) if arr[0] == 0: arr[0] = float("nan") + else: arr = self.surprisals a.plot( - arr + np.random.rand(len(self)) / 10, + arr + np.random.rand(len(arr)) / 10, ".--", lw=2, label=" ".join(self.tokens), @@ -135,16 +137,17 @@ def lineplot(self, f=None, a=None, cumulative=False): a.set( xticks=range(0, len(self.tokens)), xlabel=("tokens"), - ylabel=( - f"{'cumulative ' if cumulative else ''}surprisal (natural log scale)" - ), + ylabel=(f"{'cumulative ' if cumulative else ''}surprisal\n(natlog scale)"), ) # plt.legend(bbox_to_anchor=(0, -0.1), loc="upper left") plt.tight_layout() a.grid(visible=True) for i, (t, y) in enumerate(self): - a.annotate(t, (i, arr[i])) + if i < len(arr): + a.annotate(t, (i, arr[i])) + else: + break return f, a