-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
90 lines (77 loc) · 3.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import string
import torch
import numpy as np
import re
from typing import List, Dict, Tuple
import random
from tqdm import tqdm
def _featurize(
    toks: List[str], tok2id: Dict[str, int], device: str = "cpu"
) -> torch.Tensor:
    """Map tokens to a 1-D tensor of vocabulary ids.

    Lookup is case-insensitive: each token is lower-cased before being looked
    up in ``tok2id``. Tokens that are not in the vocabulary map to id 0.

    Args:
        toks: Tokens to convert.
        tok2id: Vocabulary mapping lower-cased token -> integer id.
        device: Device the resulting tensor is placed on.

    Returns:
        An int64 tensor of shape ``(len(toks),)`` on ``device``.
    """
    tok_ids = []
    for tok in toks:
        try:
            tok_ids.append(tok2id[tok.lower()])
        except KeyError:
            # Out-of-vocabulary fallback; presumably id 0 is the reserved
            # unknown/padding id — NOTE(review): confirm with the vocab builder.
            tok_ids.append(0)
    # Build the integer tensor directly. Going through ``torch.Tensor`` (float32)
    # first would silently round ids larger than 2**24.
    return torch.tensor(tok_ids, dtype=torch.long, device=device)
def load_glove_embeddings(
    filepath: str, embed_dim: int
) -> Tuple[torch.Tensor, List[str]]:
    """Load word vectors from a GloVe-format text file.

    Each line is expected to be ``word v1 v2 ... v{embed_dim}``, with fields
    separated by single spaces.

    Args:
        filepath: Path to the GloVe text file (decoded as UTF-8).
        embed_dim: Number of values per vector.

    Returns:
        ``(vectors, itos)``. ``vectors`` is a float32 tensor of shape
        ``(num_lines, embed_dim)``, and ``itos[i]`` is the word whose vector is
        ``vectors[i]``.

    Raises:
        ValueError: If a line does not contain exactly ``embed_dim`` values.
    """
    print("Loading Glove...")

    def get_num_lines(f):
        """Count the lines in file handle ``f``, then rewind it to the start."""
        num_lines = sum(1 for _ in f)
        f.seek(0)
        return num_lines

    itos = []
    # Decode explicitly as UTF-8. Without an encoding, open() uses the platform
    # locale encoding, which can raise UnicodeDecodeError on non-ASCII words.
    with open(filepath, "r", encoding="utf-8") as f:
        num_lines = get_num_lines(f)
        vectors = torch.zeros(num_lines, embed_dim, dtype=torch.float32)
        for i, line in enumerate(tqdm(f, total=num_lines)):
            # Strip trailing whitespace (including the newline) so a stray space
            # at the end of a line does not produce an empty, unparsable field.
            word, *values = line.rstrip().split(" ")
            if len(values) != embed_dim:
                raise ValueError(
                    f"Line {i + 1} of {filepath} has {len(values)} values, "
                    f"expected {embed_dim}"
                )
            itos.append(word)
            vectors[i] = torch.tensor([float(x) for x in values])
    print(f"{len(itos)} words loaded!")
    return (vectors, itos)
def get_pretrained_weights(
    embed_dim: int, tok2id: Dict[str, int], glove_dir: str = "../glove"
) -> torch.Tensor:
    """Build an embedding matrix for ``tok2id`` from pretrained GloVe 6B vectors.

    Row ``ix`` of the result holds the GloVe vector of the token mapped to
    ``ix``, looked up lower-cased. Tokens absent from GloVe get a random vector.
    It is drawn from a normal distribution whose mean and std are those of all
    GloVe values, so these rows start on the same scale as the pretrained ones.

    Args:
        embed_dim: Embedding size; selects ``glove.6B.{embed_dim}d.txt``.
        tok2id: Vocabulary mapping token -> row index. Indices are presumably
            in ``[0, len(tok2id))`` — NOTE(review): confirm with the caller.
        glove_dir: Directory containing the GloVe text files. Defaults to the
            previously hard-coded ``../glove``, resolved relative to the CWD.

    Returns:
        A float32 tensor of shape ``(len(tok2id), embed_dim)``.

    Raises:
        ValueError: If the GloVe file does not exist.
    """
    glove_path = os.path.join(glove_dir, f"glove.6B.{embed_dim}d.txt")
    if not os.path.exists(glove_path):
        raise ValueError(f"Glove file does not exist: {glove_path}")
    vectors, itos = load_glove_embeddings(filepath=glove_path, embed_dim=embed_dim)
    # Initialise OOV words (relative to GloVe) with the same distribution as the
    # others. Plain floats select the (mean, std, size) overload of torch.normal.
    glove_mean = vectors.mean().item()
    glove_std = vectors.std().item()
    # Index the GloVe words once. A dict lookup is O(1), whereas
    # ``tok in itos`` / ``itos.index(tok)`` scan the whole list for every token.
    # setdefault keeps the first occurrence, matching ``list.index``.
    stoi: Dict[str, int] = {}
    for row, word in enumerate(itos):
        stoi.setdefault(word, row)
    weights = torch.zeros((len(tok2id), embed_dim), dtype=torch.float32)
    found = 0
    for tok, ix in tok2id.items():
        tok = tok.lower()
        row = stoi.get(tok)
        if row is not None:
            weights[ix, :] = vectors[row]
            found += 1
        else:
            print(f"Word not in glove: {tok}")
            weights[ix, :] = torch.normal(glove_mean, glove_std, size=(embed_dim,))
    print(f"{found} out of {len(tok2id)} words found.")
    return weights
def detokenize(tokens: List[str]) -> List[int]:
    """
    Compute, for each token, whether it should be preceded by a space.

    A token is glued to the previous one (flag 0) in three cases:
    - it starts with an apostrophe (e.g. "'s");
    - it is exactly "n't";
    - it passes a *substring* test against ``string.punctuation``. Single
      punctuation marks, the empty string, and contiguous runs of it such as
      "()" all count; "..." does not.

    Every other token gets flag 1.

    Returns:
        A list of 0/1 ints, one per token, in the same order as ``tokens``.
    """
    flags = []
    for tok in tokens:
        glued = tok.startswith("'") or tok == "n't" or tok in string.punctuation
        flags.append(0 if glued else 1)
    return flags
def get_example_script(corpus: "TACLCorpus", script_type: str) -> str:
    """Return a random InScript test-set prefix that ends just before a masked word.

    A document whose ``name`` starts with ``script_type`` is picked at random
    from ``corpus.test``. One of its masked words is then picked at random, and
    the tokens preceding it are detokenized into a single string. The masked
    word itself is not included, so it is the "next token" after the result.

    Args:
        corpus: Corpus exposing a ``test`` collection of documents. Each
            document has a ``name`` and is an indexable, iterable sequence of
            words with ``text`` and ``masked`` attributes.
        script_type: Prefix of the document names to draw from.

    Returns:
        The detokenized text preceding the chosen masked word. It is empty if
        that word is the document's first token.

    Raises:
        ValueError: If no test document matches ``script_type``, or if the
            chosen document has no masked words.
    """
    genre_docs = [doc for doc in corpus.test if doc.name.startswith(script_type)]
    if not genre_docs:
        raise ValueError(f"No test documents found for script type: {script_type!r}")
    # In CPython, random.choice makes the same single _randbelow(n) draw as
    # randint(0, n - 1), so seeded runs pick the same items as before.
    chosen_doc = random.choice(genre_docs)
    masked_word_ixs = [ix for ix, word in enumerate(chosen_doc) if word.masked]
    if not masked_word_ixs:
        raise ValueError(f"Document {chosen_doc.name!r} contains no masked words")
    chosen_word_ix = random.choice(masked_word_ixs)
    tokens = [chosen_doc[i].text for i in range(chosen_word_ix)]
    # detokenize yields 1 for tokens that take a leading space, 0 for glued ones.
    text = "".join(
        " " + tok if flag == 1 else tok
        for tok, flag in zip(tokens, detokenize(tokens))
    )
    # Drop the leading space added before the first token.
    return text.strip()