Commit

Merge branch 'benlipkin/main' into main
aalok-sathe committed Nov 15, 2023
2 parents 3bd1d54 + ab99b0e commit 061df76
Showing 3 changed files with 53 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ license = "MIT"
 python = "^3.8"
 transformers = "^4.20.1"
 numpy = "^1.23.1"
-torch = "^1.12.0"
+torch = "^2.0.0"
 plotext = "^5.0.2"
 matplotlib = "^3.5.2"
 pandas = "^1.4.3"
43 changes: 38 additions & 5 deletions surprisal/model.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import numpy.typing
+import torch
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForMaskedLM,
@@ -103,12 +104,26 @@ def __init__(
         model_id: str,
         model_class: typing.Callable,
         device: str = "cpu",
+        precision: str = "fp32",
+        trust_remote_code: bool = False,
     ) -> None:
         super().__init__(model_id)
 
+        precisions = {
+            "fp32": torch.float32,
+            "fp16": torch.float16,
+            "bf16": torch.bfloat16,
+        }
+        if precision not in precisions:
+            raise ValueError(
+                f"precision must be one of {list(precisions.keys())}, got {precision}"
+            )
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        # self.model_class = model_class
-        self.model: PreTrainedModel = model_class.from_pretrained(self.model_id)
+        self.model: PreTrainedModel = model_class.from_pretrained(
+            self.model_id,
+            torch_dtype=precisions[precision],
+            trust_remote_code=trust_remote_code,
+        )
         self.model.eval()
         self.to(device)  # initializes a variable called `device`
 
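The new options ride through as ordinary keyword arguments, so callers opt into reduced precision at construction time. A minimal sketch of the invocation (the model id and device are illustrative choices, not part of the commit):

```python
# Illustrative only: "gpt2" and "cuda" are example choices; half precision
# generally assumes a GPU, so fall back to precision="fp32" on CPU.
import surprisal

m = surprisal.CausalHuggingFaceModel(
    model_id="gpt2",
    device="cuda",
    precision="fp16",          # one of "fp32", "fp16", "bf16"
    trust_remote_code=False,   # forwarded to from_pretrained
)
```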
@@ -166,6 +181,7 @@ def __init__(self, model_id=None, **kwargs) -> None:
         if "model_class" not in kwargs:
             kwargs.update(dict(model_class=AutoModelForCausalLM))
         super().__init__(model_id, **kwargs)
+        self.tokenizer.padding_side = "right"
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
     def surprise(
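GPT-style tokenizers ship without a pad token, so batching unequal-length inputs fails unless one is assigned; reusing EOS as the pad token with right-padding is the standard workaround, here made explicit. A standalone sketch of the effect (the tokenizer and sentences are illustrative):

```python
# Standalone illustration, not from the commit: with EOS reused as the pad
# token and right padding, a batch of unequal-length sentences tokenizes to
# a rectangular tensor whose attention mask zeroes out the padded positions.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.padding_side = "right"
tok.pad_token = tok.eos_token

batch = tok(
    ["The cat sat.", "I am going on a bear hunt."],
    padding=True,
    return_tensors="pt",
)
print(batch.input_ids.shape)    # torch.Size([2, <max_len>])
print(batch.attention_mask[0])  # trailing zeros mark the padding
```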
@@ -200,14 +216,28 @@ def surprise(
                 ),
                 dim=1,
             )
+            mask = torch.concat(
+                (
+                    # TODO: need to evaluate what happens if this is set to 0 for the BOS token
+                    torch.tensor([1])
+                    .view(1, -1)
+                    .repeat(tokenized.input_ids.shape[0], 1),
+                    tokenized.attention_mask,
+                ),
+                dim=1,
+            )
             # raise NotImplementedError
         else:
             ids = tokenized.input_ids
+            mask = tokenized.attention_mask
 
         ids = ids.to(self.device)
+        mask = mask.to(self.device)
 
         with torch.no_grad():
             output = self.model(
-                ids,
+                input_ids=ids,
+                attention_mask=mask,
+                return_dict=True,
             )
         tokenized = tokenized.to(self.device)
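The substance of this hunk: when `use_bos_token` prepends a BOS id to every sequence, the attention mask must grow by a matching column of ones, and the mask now travels to the model explicitly via `attention_mask=`. A standalone sketch of the mask extension (the example batch is made up):

```python
# Standalone sketch of the mask extension above; the example batch is made up.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],   # padded 3-token sentence
                               [1, 1, 1, 1]])  # full 4-token sentence
bos_column = torch.tensor([1]).view(1, -1).repeat(attention_mask.shape[0], 1)
mask = torch.concat((bos_column, attention_mask), dim=1)
print(mask)
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```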
@@ -233,15 +263,18 @@ def surprise(
         )
         if not use_bos_token:
             # padding to the left with a NULL because we removed the BOS token
-            logprobs = torch.concat((torch.ones(b, 1) * torch.nan, logprobs), dim=1)
+            logprobs = torch.concat(
+                ((torch.ones(b, 1) * torch.nan).to(self.device), logprobs), dim=1
+            )
 
         # b stands for an individual item in the batch; each sentence is one item
         # since this is an autoregressive model
         accumulator = []
         for b in range(logprobs.shape[0]):
             accumulator += [
                 HuggingFaceSurprisal(
-                    tokens=tokenized[b], surprisals=-logprobs[b, :].cpu().numpy()
+                    tokens=tokenized[b],
+                    surprisals=-logprobs[b, :].cpu().float().numpy(),
                 )
             ]
         return accumulator
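Both fixes in this hunk follow from the new device and precision options: the NaN padding tensor must live on the same device as `logprobs` before concatenation, and half-precision tensors need an upcast before NumPy conversion because NumPy has no bfloat16 dtype. A quick standalone illustration of the latter:

```python
# Standalone illustration, not from the commit: .numpy() on a bf16 tensor
# raises, so .float() upcasts to fp32 first, which works for every precision.
import torch

t = torch.randn(3, dtype=torch.bfloat16)
try:
    t.numpy()
except TypeError as e:
    print(e)                 # e.g. "Got unsupported ScalarType BFloat16"

arr = t.float().numpy()      # upcast, then convert
print(arr.dtype)             # float32
```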
14 changes: 14 additions & 0 deletions tests/test_causallm.py
@@ -11,6 +11,14 @@ def test_init_model(model_id):
     m = surprisal.CausalHuggingFaceModel(model_id=model_id)
 
 
+@pytest.mark.parametrize("model_id, stim", [("gpt2", "The cat sat on the mat.")])
+def test_compute_surprisal_unconditional(model_id, stim):
+    import surprisal
+
+    m = surprisal.CausalHuggingFaceModel(model_id=model_id)
+    surp = m.surprise(stim)
+
+
 @pytest.mark.parametrize(
     "model_id, stim_plaus, stim_implaus, expected_surp_plaus, expected_surp_implaus",
     [
@@ -45,3 +53,9 @@ def test_compute_surprisal_relative(model_id, stim_plaus, stim_implaus):
     m = surprisal.CausalHuggingFaceModel(model_id=model_id)
     [surp_plaus, surp_implaus] = m.surprise([stim_plaus, stim_implaus])
     assert surp_plaus[0 : len(stim_plaus)] < surp_implaus[0 : len(stim_implaus)]
+
+
+if __name__ == "__main__":
+    test_compute_surprisal_unconditional(
+        "sshleifer/tiny-gpt2", ["The cat sat.", "I am going on a bear hunt."]
+    )

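For a quick interactive check outside pytest, something along these lines mirrors what the tests exercise (a sketch: it assumes the package is installed, "gpt2" downloads on first use, and the span indexing follows the pattern the relative test relies on):

```python
# Sketch of interactive use mirroring the tests above; "gpt2" and the
# sentence are illustrative choices.
import surprisal

m = surprisal.CausalHuggingFaceModel(model_id="gpt2")
[surp] = m.surprise(["The cat sat on the mat."])
print(surp)        # per-token surprisals
print(surp[0:7])   # surprisal aggregated over the "The cat" character span
```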