diff --git a/surprisal/interface.py b/surprisal/interface.py index ddbde1e..4418f98 100644 --- a/surprisal/interface.py +++ b/surprisal/interface.py @@ -17,7 +17,7 @@ def __init__(self, model_id=None) -> None: self.model_id = model_id @abstractmethod - def surprise(self, textbatch: typing.Union[typing.List, str]) -> "Surprisal": + def surprise(self, textbatch: typing.Union[typing.List, str]) -> "SurprisalArray": raise NotImplementedError @@ -118,7 +118,13 @@ def lineplot(self, f=None, a=None, cumulative=False): if f is None or a is None: f, a = plt.subplots() - arr = np.cumsum(self.surprisals) if cumulative else self.surprisals + if cumulative: + arr = np.nan_to_num(self.surprisals.astype("float64")) + arr = np.cumsum(arr) + if arr[0] == 0: + arr[0] = float("nan") + else: + arr = self.surprisals a.plot( arr + np.random.rand(len(self)) / 10, ".--",