Skip to content

Commit

Permalink
adjustments for batched results with variable-length stimuli in the c…
Browse files Browse the repository at this point in the history
…ase of `cumulative=True` in `SurprisalArray.lineplot`
  • Loading branch information
aalok-sathe committed Dec 3, 2024
1 parent 319d03c commit bd49e84
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions surprisal/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,16 @@ def lineplot(self, f=None, a=None, cumulative=False):
f, a = plt.subplots()

if cumulative:
arr = np.nan_to_num(self.surprisals.astype("float64"))
arr = self.surprisals.astype("float64")
arr = np.nan_to_num(arr, nan=0.0, posinf=float("nan"))
arr = np.cumsum(arr)
if arr[0] == 0:
arr[0] = float("nan")

else:
arr = self.surprisals
a.plot(
arr + np.random.rand(len(self)) / 10,
arr + np.random.rand(len(arr)) / 10,
".--",
lw=2,
label=" ".join(self.tokens),
Expand All @@ -135,16 +137,17 @@ def lineplot(self, f=None, a=None, cumulative=False):
a.set(
xticks=range(0, len(self.tokens)),
xlabel=("tokens"),
ylabel=(
f"{'cumulative ' if cumulative else ''}surprisal (natural log scale)"
),
ylabel=(f"{'cumulative ' if cumulative else ''}surprisal\n(natlog scale)"),
)
# plt.legend(bbox_to_anchor=(0, -0.1), loc="upper left")
plt.tight_layout()
a.grid(visible=True)

for i, (t, y) in enumerate(self):
a.annotate(t, (i, arr[i]))
if i < len(arr):
a.annotate(t, (i, arr[i]))
else:
break

return f, a

Expand Down

0 comments on commit bd49e84

Please sign in to comment.