Skip to content

Commit

Permalink
plotting was breaking in the case of cumulative and use_bos_token=Fal…
Browse files Browse the repository at this point in the history
…se due to NaNs in `np.cumsum`; this commit provides workaround
  • Loading branch information
aalok-sathe committed Dec 2, 2024
1 parent 6c914a2 commit 319d03c
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions surprisal/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, model_id=None) -> None:
self.model_id = model_id

@abstractmethod
def surprise(self, textbatch: typing.Union[typing.List, str]) -> "Surprisal":
def surprise(self, textbatch: typing.Union[typing.List, str]) -> "SurprisalArray":
raise NotImplementedError


Expand Down Expand Up @@ -118,7 +118,13 @@ def lineplot(self, f=None, a=None, cumulative=False):
if f is None or a is None:
f, a = plt.subplots()

arr = np.cumsum(self.surprisals) if cumulative else self.surprisals
if cumulative:
arr = np.nan_to_num(self.surprisals.astype("float64"))
arr = np.cumsum(arr)
if arr[0] == 0:
arr[0] = float("nan")
else:
arr = self.surprisals
a.plot(
arr + np.random.rand(len(self)) / 10,
".--",
Expand Down

0 comments on commit 319d03c

Please sign in to comment.