Merge pull request #14 from aalok-sathe/feature-support-kenlm
Merging this as it doesn't introduce any changes to existing behavior; it only adds a new implementation to support the KenLM model class. Merging even though we have a few TODOs left to address.
aalok-sathe authored Nov 8, 2023
2 parents e9e80fd + 83ba14d commit 4cbee05
Showing 7 changed files with 374 additions and 45 deletions.
132 changes: 108 additions & 24 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -17,6 +17,7 @@ plotext = "^5.0.2"
matplotlib = "^3.5.2"
pandas = "^1.4.3"
openai = "^0.23.0"
kenlm = {version = "^0.2.0", optional = true}

[tool.poetry.dev-dependencies]
ipython = "^8.4.0"
1 change: 1 addition & 0 deletions surprisal/__init__.py
@@ -4,6 +4,7 @@
CausalHuggingFaceModel,
MaskedHuggingFaceModel,
OpenAIModel,
KenLMModel,
)

from surprisal.interface import SurprisalQuantity
135 changes: 135 additions & 0 deletions surprisal/interface.py
@@ -33,6 +33,38 @@ def __index__(self):
def __len__(self):
return len(self.surprisals)

def __repr__(self) -> str:
"""
        nicely formatted string of surprisals aligned with their
        corresponding tokens/substrings, which can be sliced into
        using the `__getitem__` method
"""
numfmt = "{: >10.3f}"
strfmt = "{: >10}"
accumulator = ""
for t in self.tokens:
accumulator += strfmt.format(t[:10]) + " "
accumulator += "\n"
for s in self.surprisals:
accumulator += numfmt.format(s) + " "
return accumulator

def __getitem__(
self, slctup: typing.Tuple[typing.Union[slice, int], str]
) -> SurprisalQuantity:
"""Returns the aggregated surprisal over a character
Args:
slctup (typing.Tuple[typing.Union[slice, int], str]):
`(slc, slctype) = slctup`: a tuple of a `slc` (slice) and a `slctype` (str).
`slc` gives the slice of the original string we want to aggregate surprisal over.
`slctype` indicates if it should be a "char" slice or a "word" slice.
if a character falls inside a token, then that entire token is included.
Returns:
float: the aggregated surprisal over the word span
"""
raise NotImplementedError

@property
@abstractmethod
def tokens(self):
@@ -74,3 +106,106 @@ def lineplot(self, f=None, a=None, cumulative=False):
a.annotate(t, (i, arr[i]))

return f, a


class CustomEncoding:
"""
a duck-typed clone of the huggingface tokenizers' return class
`tokenizers.Encoding`
    that packages simple custom-tokenized text together with its
    character and word spans, allowing indexing into the tokenized
    object by character or word span.
the goal is for this class to be capable of being passed to
`hf_pick_matching_token_ixs` with the signature
```python
surprisal.utils.hf_pick_matching_token_ixs(
encoding: "tokenizers.Encoding", span_of_interest: slice, span_type: str
) -> slice
```
    and that's about it. it does not provide implementations of anything else,
    since huggingface makes it really difficult to actually re-use any of the
    Rust implementation of tokenizers in Python.
Arguments:
----------
`tokens` (typing.Iterable[str]): the tokens in the tokenized text
`spans` (typing.Iterable[typing.Tuple[int]]): the character spans of each token
`original_str` (str): the original string that was tokenized
E.g., the input to tokens and spans would be the result of the following output from
`tokenizers.pre_tokenizers.Whitespace().pre_tokenize_str("hi my name is language model")`:
[('hi', (0, 2)),
('my', (3, 5)),
('name', (6, 10)),
('is', (11, 13)),
('language', (14, 22)),
('model', (23, 29))]
"""

def __init__(
self,
tokens: typing.Iterable[str],
spans: typing.Iterable[typing.Tuple[int]],
original_str: str,
ids: typing.Iterable[int] = None,
) -> None:
self.tokens = tokens
self.spans = spans
self.original_str = original_str
self._ids = ids

def token_to_chars(self, token_index) -> typing.Tuple[int, int]:
"""
Get the offsets of the token at the given index.
The returned offsets are related to the input sequence that contains the
token. In order to determine in which input sequence it belongs, you
must call :meth:`~tokenizers.Encoding.token_to_sequence()`.
Args:
token_index (:obj:`int`):
The index of a token in the encoded sequence.
Returns:
:obj:`Tuple[int, int]`: The token offsets :obj:`(first, last + 1)`
"""
return self.spans[token_index]

def token_to_word(self, token_index):
"""
Get the index of the word that contains the token in one of the input sequences.
The returned word index is related to the input sequence that contains
the token. In order to determine in which input sequence it belongs, you
must call :meth:`~tokenizers.Encoding.token_to_sequence()`.
Args:
token_index (:obj:`int`):
The index of a token in the encoded sequence.
Returns:
:obj:`int`: The index of the word in the relevant input sequence.
"""
# assuming this is going to be primarily used for whitespace-tokenized text
# TODO: this method will need to be fleshed out using the character spans to
# match the tokens to their corresponding words if we ever want to support a
# custom tokenization scheme that isn't just whitespace.
# this is possible, but will skip implementing for now
return token_index

@property
def ids(self):
"""
The generated IDs
        The IDs are the main input to a Language Model. They are the token indices,
        the numerical representations that an LM understands.
Returns:
:obj:`List[int]`: The list of IDs
"""
# IDs are not applicable to non-LM tokenization, unless explicitly specified
if self._ids:
return self._ids
return [0] * len(self.tokens)
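
To make the duck-typing concrete, here is a minimal sketch (not part of this diff) of how the pre-tokenizer output from the docstring above maps onto a `CustomEncoding`:

```python
from tokenizers.pre_tokenizers import Whitespace
from surprisal.interface import CustomEncoding

text = "hi my name is language model"
# Whitespace().pre_tokenize_str returns [(token, (start, end)), ...]
tokens_and_spans = Whitespace().pre_tokenize_str(text)
tokens, spans = zip(*tokens_and_spans)

enc = CustomEncoding(tokens, spans, text)
print(enc.token_to_chars(2))  # (6, 10), the character span of 'name'
print(enc.ids)                # [0, 0, 0, 0, 0, 0], since no ids were given
```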
79 changes: 75 additions & 4 deletions surprisal/model.py
@@ -2,22 +2,93 @@
import logging
from abc import abstractmethod
from functools import partial
from pathlib import Path

import numpy as np
import numpy.typing
from transformers import (
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoTokenizer,
PreTrainedModel,
)
from tokenizers.pre_tokenizers import Whitespace, PreTokenizer

-from surprisal.utils import pick_matching_token_ixs, openai_models_list
-from surprisal.interface import Model, SurprisalArray, SurprisalQuantity
-from surprisal.surprisal import HuggingFaceSurprisal
+from surprisal.utils import hf_pick_matching_token_ixs, openai_models_list
+from surprisal.interface import Model, SurprisalArray, SurprisalQuantity, CustomEncoding
+from surprisal.surprisal import HuggingFaceSurprisal, NGramSurprisal

logger = logging.getLogger(name="surprisal")


class KenLMModel(Model):
"""
A class utilizing the `kenlm` library to compute surprisal using
pretrained kenlm models
"""

def __init__(self, model_path: typing.Union[str, Path], **kwargs) -> None:
super().__init__(str(model_path))

import kenlm

self.tokenizer = Whitespace()

        self.model = kenlm.Model(str(model_path))  # kenlm expects a plain string path
self.state_in = kenlm.State()
self.state_out = kenlm.State()

def tokenize(
self, textbatch: typing.Union[typing.List, str]
) -> typing.Iterable[CustomEncoding]:
if type(textbatch) is str:
textbatch = [textbatch]

tokenized = map(self.tokenizer.pre_tokenize_str, textbatch)

        for text, tokens_and_spans in zip(textbatch, tokenized):
            # unzip [(token, span), ...] into parallel token and span tuples
            tokens, spans = zip(*tokens_and_spans)
            # pair each encoding with its own source string
            yield CustomEncoding(tokens, spans, text)

def surprise(
self, textbatch: typing.Union[typing.List, str]
) -> typing.List[NGramSurprisal]:
import kenlm

if type(textbatch) is str:
textbatch = [textbatch]

def score_sent(
sent: CustomEncoding,
m: kenlm.Model = self.model,
bos: bool = True,
eos: bool = True,
) -> np.typing.NDArray[float]:
st1, st2 = kenlm.State(), kenlm.State()
if bos:
m.BeginSentenceWrite(st1)
else:
m.NullContextWrite(st1)
words = sent.tokens
accum = []
for w in words:
accum += [m.BaseScore(st1, w, st2)]
st1, st2 = st2, st1
if eos:
accum += [m.BaseScore(st1, "</s>", st2)]
return np.array(accum)

tokenized = [*self.tokenize(textbatch)]
scores = [*map(score_sent, tokenized)]

accumulator = []
for b in range(len(textbatch)):
accumulator += [NGramSurprisal(tokens=tokenized[b], surprisals=-scores[b])]
return accumulator


###############################################################################
### model classes to compute surprisal
###############################################################################
@@ -66,7 +137,7 @@ def tokenize(self, textbatch: typing.Union[typing.List, str], max_length=1024):
@abstractmethod
def surprise(
self, textbatch: typing.Union[typing.List, str]
-    ) -> HuggingFaceSurprisal:
+    ) -> typing.List[HuggingFaceSurprisal]:
raise NotImplementedError

def extract_surprisal(
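
For orientation, the new `KenLMModel` can be exercised roughly as follows. This is a sketch, not part of the diff: the model path is hypothetical, and it assumes a pretrained KenLM n-gram model (`.arpa` or binarized `.bin`) plus the optional `kenlm` dependency are available:

```python
from surprisal import KenLMModel

# hypothetical path to a pretrained n-gram model
m = KenLMModel(model_path="models/en-3gram.bin")

# surprise() wraps a bare string into a one-element batch and
# returns one NGramSurprisal per input string
[s] = m.surprise("hi my name is language model")
print(s)               # tokens printed above their surprisal values
print(s[0:3, "word"])  # surprisal aggregated over the first three words
```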
68 changes: 52 additions & 16 deletions surprisal/surprisal.py
@@ -4,8 +4,8 @@
from functools import partial

import numpy as np
-from surprisal.utils import pick_matching_token_ixs
-from surprisal.interface import Model, SurprisalArray, SurprisalQuantity
+from surprisal.utils import hf_pick_matching_token_ixs
+from surprisal.interface import CustomEncoding, Model, SurprisalArray, SurprisalQuantity

logger = logging.getLogger(name="surprisal")

@@ -31,13 +31,15 @@ def tokens(self):
return self._tokens.tokens

@property
-    def surprisals(self):
+    def surprisals(self) -> np.typing.NDArray[SurprisalQuantity]:
return self._surprisals

def __iter__(self) -> typing.Tuple[str, float]:
return zip(self.tokens, self.surprisals)

-    def __getitem__(self, slctup: typing.Tuple[typing.Union[slice, int], str]):
+    def __getitem__(
+        self, slctup: typing.Tuple[typing.Union[slice, int], str]
+    ) -> SurprisalQuantity:
"""Returns the aggregated surprisal over a character
Args:
Expand All @@ -58,9 +60,9 @@ def __getitem__(self, slctup: typing.Tuple[typing.Union[slice, int], str]):
slc, slctype = slctup, "char"

if slctype == "char":
fn = partial(pick_matching_token_ixs, span_type="char")
fn = partial(hf_pick_matching_token_ixs, span_type="char")
elif slctype == "word":
fn = partial(pick_matching_token_ixs, span_type="word")
fn = partial(hf_pick_matching_token_ixs, span_type="word")

if type(slc) is int:
slc = slice(slc, slc + 1)
Expand All @@ -70,16 +72,50 @@ def __getitem__(self, slctup: typing.Tuple[typing.Union[slice, int], str]):
self.surprisals[token_slc].sum(), " ".join(self.tokens[token_slc])
)

-    def __repr__(self) -> str:
-        numfmt = "{: >10.3f}"
-        strfmt = "{: >10}"
-        accumulator = ""
-        for t in self.tokens:
-            accumulator += strfmt.format(t[:10]) + " "
-        accumulator += "\n"
-        for s in self.surprisals:
-            accumulator += numfmt.format(s) + " "
-        return accumulator

class NGramSurprisal(HuggingFaceSurprisal):
def __init__(
self,
tokens: typing.List[CustomEncoding],
surprisals: np.ndarray,
) -> None:
super().__init__(tokens, surprisals.astype(SurprisalQuantity))

def __getitem__(
self, slctup: typing.Tuple[typing.Union[slice, int], typing.Optional[str]]
):
"""Returns the aggregated surprisal over a character
Args:
slctup (typing.Tuple[typing.Union[slice, int], str]):
`(slc, slctype) = slctup`: a tuple of a `slc` (slice) and a `slctype` (str).
`slc` gives the slice of the original string we want to aggregate surprisal over.
`slctype` indicates if it should be a "char" slice or a "word" slice.
if a character falls inside a token, then that entire token is included.
Returns:
float: the aggregated surprisal over the word span
"""
try:
slc, slctype = slctup
if slctype not in ("word", "char"):
raise ValueError(f"unrecognized slice type {slctype}")
except TypeError:
# slctup is not a tuple, but just a slice or int
slc, slctype = slctup, "char"

if slctype == "char":
raise NotImplementedError('WIP; currently only supports "word" spans')
fn = partial(hf_pick_matching_token_ixs, span_type="char")
elif slctype == "word":
token_slc = slc

if type(slc) is int:
slc = slice(slc, slc + 1)

return SurprisalQuantity(
self.surprisals[token_slc].sum(), " ".join(self.tokens[token_slc])
)


class PCFGSurprisal(SurprisalArray):
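
A quick illustration of the word-slice behavior above (a sketch, assuming whitespace tokenization; the surprisal values are made up): unlike `HuggingFaceSurprisal`, which maps character and word spans through `hf_pick_matching_token_ixs`, `NGramSurprisal` treats a `"word"` slice as indexing the tokens directly:

```python
import numpy as np
from surprisal.interface import CustomEncoding
from surprisal.surprisal import NGramSurprisal

enc = CustomEncoding(
    tokens=("hi", "my", "name"),
    spans=((0, 2), (3, 5), (6, 10)),
    original_str="hi my name",
)
s = NGramSurprisal(tokens=enc, surprisals=np.array([2.0, 1.5, 3.0]))

s[1, "word"]    # surprisal of 'my' alone
s[0:2, "word"]  # 3.5, summed over 'hi my'
s[0:2, "char"]  # raises NotImplementedError: char spans are WIP
```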
3 changes: 2 additions & 1 deletion surprisal/utils.py
@@ -1,7 +1,8 @@
import tokenizers
from transformers import tokenization_utils_base


-def pick_matching_token_ixs(
+def hf_pick_matching_token_ixs(
encoding: "tokenizers.Encoding", span_of_interest: slice, span_type: str
) -> slice:
"""Picks token indices in a tokenized encoded sequence that best correspond to
