Skip to content

Commit

Permalink
elevate bos and eos args to surprise()
Browse files Browse the repository at this point in the history
  • Loading branch information
aalok-sathe committed Nov 17, 2023
1 parent 1f09253 commit 6b986af
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions surprisal/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def tokenize(
def surprise(
self,
textbatch: typing.Union[typing.List, str],
use_bos_token: bool = True,
use_eos_token: bool = True,
) -> typing.List[NGramSurprisal]:
import kenlm

Expand All @@ -72,11 +74,9 @@ def surprise(
def score_sent(
sent: CustomEncoding,
m: kenlm.Model = self.model,
bos: bool = True,
eos: bool = True,
) -> np.typing.NDArray[float]:
st1, st2 = kenlm.State(), kenlm.State()
if bos:
if use_bos_token:
m.BeginSentenceWrite(st1)
else:
m.NullContextWrite(st1)
Expand All @@ -85,7 +85,7 @@ def score_sent(
for w in words:
accum += [m.BaseScore(st1, w, st2)]
st1, st2 = st2, st1
if eos:
if use_eos_token:
accum += [m.BaseScore(st1, "</s>", st2)]
return np.array(accum)

Expand Down

0 comments on commit 6b986af

Please sign in to comment.