From 6b986af51832fd757e05b5b6f99b85458d003a9c Mon Sep 17 00:00:00 2001 From: aalok-sathe Date: Thu, 16 Nov 2023 20:48:16 -0500 Subject: [PATCH] elevate bos and eos args to `surprise()` --- surprisal/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/surprisal/model.py b/surprisal/model.py index 94cf0f2..30062f3 100644 --- a/surprisal/model.py +++ b/surprisal/model.py @@ -63,6 +63,8 @@ def tokenize( def surprise( self, textbatch: typing.Union[typing.List, str], + use_bos_token: bool = True, + use_eos_token: bool = True, ) -> typing.List[NGramSurprisal]: import kenlm @@ -72,11 +74,9 @@ def surprise( def score_sent( sent: CustomEncoding, m: kenlm.Model = self.model, - bos: bool = True, - eos: bool = True, ) -> np.typing.NDArray[float]: st1, st2 = kenlm.State(), kenlm.State() - if bos: + if use_bos_token: m.BeginSentenceWrite(st1) else: m.NullContextWrite(st1) @@ -85,7 +85,7 @@ def score_sent( for w in words: accum += [m.BaseScore(st1, w, st2)] st1, st2 = st2, st1 - if eos: + if use_eos_token: accum += [m.BaseScore(st1, "", st2)] return np.array(accum)