-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcorpus.py
373 lines (330 loc) · 14.4 KB
/
corpus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
# Load and process a corpus into n-gram frequencies, subject to certain settings.
# Members:
#   shift_key: str
#   space_key: str
#   key_counts: dict[str, int]
#   bigram_counts: dict[Bigram, int]
#   trigram_counts: dict[Trigram, int]
#   trigrams_by_freq: list[Trigram] - possibly just use trigram_counts.most_common()
#   precision: int
#   trigram_completeness: float
#   replacements: dict[str, tuple[str, ...]]
#   special_replacements: dict[str, tuple[str, ...]]
# Local variables:
#   raw: raw text of the corpus, read directly from a file
#   processed: a list of 1-grams (may not be necessary)
from collections import Counter
import itertools
import json
from typing import Type
# N-grams are tuples of key names.
Bigram = tuple[str, str]
Trigram = tuple[str, str, str]

# Characters typed without shift, one literal per keyboard row. The character
# at each index of default_upper is the shifted form of the character at the
# same index of default_lower.
default_lower = (
    "`1234567890-="
    "qwertyuiop[]\\"
    "asdfghjkl;'"
    "zxcvbnm,./"
)
default_upper = (
    "~!@#$%^&*()_+"
    "QWERTYUIOP{}|"
    'ASDFGHJKL:"'
    "ZXCVBNM<>?"
)
def display_name(key: str, corpus_settings: dict):
    """Return the human-readable label for `key`.

    The keys stored in `corpus_settings` under "space_key", "shift_key" and
    "repeat_key" are shown as "space", "shift" and "repeat" respectively
    (checked in that order). Any other key is returned unchanged.
    """
    special_labels = (
        ("space_key", "space"),
        ("shift_key", "shift"),
        ("repeat_key", "repeat"),
    )
    for setting, label in special_labels:
        if key == corpus_settings.get(setting, None):
            return label
    return key
def display_str(ngram: tuple[str, ...], corpus_settings: dict):
    """Return the display names of the keys in `ngram`, joined by spaces."""
    labels = [display_name(key, corpus_settings) for key in ngram]
    return " ".join(labels)
def undisplay_name(key: str, corpus_settings: dict):
    """Return the actual key that the display label `key` stands for.

    "space", "shift" and "repeat" are looked up in `corpus_settings` under
    "space_key", "shift_key" and "repeat_key". If the setting is absent, or
    `key` is not one of those labels, `key` itself is returned.
    """
    settings_by_label = {
        "space": "space_key",
        "shift": "shift_key",
        "repeat": "repeat_key",
    }
    setting = settings_by_label.get(key)
    if setting is None:
        return key
    return corpus_settings.get(setting, key)
def create_replacements(space_key: str, shift_key: str,
        special_replacements: dict[str, tuple[str,...]]):
    """Build the dict from raw corpus characters to key sequences.

    Each value is a tuple of key names. For example, "A" maps to
    (shift_key, "a") and " " maps to (space_key,). When `space_key` or
    `shift_key` is empty, "unknown" takes its place in those tuples.
    Entries in `special_replacements` override the generated ones. Every
    character of `default_lower` and `default_upper` that still has no
    entry afterwards maps to itself, like {"a": ("a",)}.
    """
    # A disabled (empty) key is represented by the "unknown" placeholder.
    space_press = space_key if space_key else "unknown"
    shift_press = shift_key if shift_key else "unknown"

    replacements = {" ": (space_press,)}
    replacements.update(
        (upper, (shift_press, lower))
        for lower, upper in zip(default_lower, default_upper)
    )
    replacements.update(special_replacements)

    # Any default character without an explicit mapping is typed as itself.
    for char in default_lower + default_upper:
        replacements.setdefault(char, (char,))
    return replacements
class TranslationError(ValueError):
    """Raised on an attempt to translate between two incompatible corpuses.

    `Corpus._translate` raises it when the settings of the source and target
    corpus differ in a way that renaming keys cannot reconcile.
    """
class Corpus:
def __init__(self, filename: str,
space_key: str = "space",
shift_key: str = "shift",
shift_policy: str = "once",
special_replacements: dict[str, tuple[str,...]] = {},
precision: int = 500,
repeat_key: str = "",
json_dict: dict = None,
other: Type["Corpus"] = None,
skipgram_weights: tuple[float] = None) -> None:
"""To disable a key, set it to `""`.
`shift_policy` can be "once" or "each". "once" means that when
consecutive capital letters occur, shift is only pressed once before
the first letter. "each" means shift is pressed before each letter.
`skipgram_weights` contains the weight in its `i`th index for pairs
of the form `pos, pos+i` for any position in the corpus. For
example, regular bigrams would be weighted by the number at index 1.
"""
self.filename = filename
self.space_key = space_key
self.shift_key = shift_key
self.shift_policy = shift_policy
self.special_replacements = special_replacements
self.repeat_key = repeat_key
self.skipgram_weights = skipgram_weights
# Not necessarily integer, due to skipgram_weights floats
self.skipgram_counts: Counter[tuple[str], float] = None
if json_dict is not None:
self._json_load(json_dict)
elif other is not None:
self._translate(other)
else:
self._process()
self.precision = precision
self.top_trigrams = ()
self.trigram_precision_total = 0
self.trigram_completeness = 0
self.set_precision(precision)
def _process(self):
self.replacements = create_replacements(
self.space_key, self.shift_key, self.special_replacements
)
replacee_lengths = sorted(set(
len(key) for key in self.replacements), reverse=True)
self.key_counts = Counter()
self.bigram_counts = Counter()
self.skip1_counts = Counter()
self.trigram_counts = Counter()
if self.skipgram_weights:
self.skipgram_counts = Counter()
def apply_replacements():
with open("corpus/" + self.filename, errors="ignore") as file:
for raw_line in file:
buffer = []
line_length = len(raw_line) - 1 # always ends with newline
i = 0
while i < line_length:
for lookahead in replacee_lengths: # longest first
if i + lookahead > line_length:
continue # remember, this can go to else
if (replacer := self.replacements.get(
raw_line[i:i+lookahead], None)) is not None:
buffer.extend(replacer)
i += lookahead
break # doesn't go to else
else:
# No replacement found
# We denote this key with "unknown"
# The rest of trialyzer knows how to handle this.
# Usually ngrams containing unknown keys are
# discarded.
buffer.add("unknown")
i += 1
yield buffer
for buffer in apply_replacements():
if bool(self.shift_key) and self.shift_policy == "once":
i = len(buffer) - 1
while i >= 2:
if (buffer[i] == self.shift_key
and buffer[i-2] == self.shift_key):
buffer.pop(i)
i -= 1
if bool(self.repeat_key):
for i in range(1, len(buffer)):
if buffer[i] == buffer[i-1]:
buffer[i] = self.repeat_key
line = tuple(buffer)
self.key_counts.update(line)
self.bigram_counts.update(itertools.pairwise(line))
self.skip1_counts.update(zip(line, line[2:]))
self.trigram_counts.update(
line[i:i+3] for i in range(len(line)-2))
if not self.skipgram_weights:
continue
for i, l1 in enumerate(line):
for sep, weight in enumerate(self.skipgram_weights):
if i+sep < len(line):
self.skipgram_counts[(l1, line[i+sep])] += weight
self.key_counts = Counter(dict(self.key_counts.most_common()))
self.bigram_counts = Counter(dict(self.bigram_counts.most_common()))
self.skip1_counts = Counter(dict(self.skip1_counts.most_common()))
self.trigram_counts = Counter(dict(self.trigram_counts.most_common()))
if self.skipgram_weights:
self.skipgram_counts = Counter(dict(
self.skipgram_counts.most_common()))
def set_precision(self, precision: int | None):
# if self.trigram_precision_total and precision == self.precision:
# return
if precision <= 0:
self.precision = 0
precision = len(self.top_trigrams)
else:
self.precision = precision
self.top_trigrams = tuple(self.trigram_counts)[:precision]
self.trigram_precision_total = sum(self.trigram_counts[tg]
for tg in self.top_trigrams)
self.trigram_completeness = (self.trigram_precision_total /
self.trigram_counts.total())
self.filtered_trigram_counts = {t: self.trigram_counts[t]
for t in self.top_trigrams}
def _json_load(self, json_dict: dict):
self.key_counts = Counter(json_dict["key_counts"])
self.bigram_counts = eval(json_dict["bigram_counts"])
self.skip1_counts = eval(json_dict["skip1_counts"])
self.trigram_counts = eval(json_dict["trigram_counts"])
self.skipgram_counts = eval(json_dict["skipgram_counts"])
def jsonable_export(self):
return {
"filename": self.filename,
"space_key": self.space_key,
"shift_key": self.shift_key,
"shift_policy": self.shift_policy,
"special_replacements": self.special_replacements,
"repeat_key": self.repeat_key,
"key_counts": self.key_counts,
"bigram_counts": repr(self.bigram_counts),
"skip1_counts": repr(self.skip1_counts),
"trigram_counts": repr(self.trigram_counts),
"skipgram_weights": self.skipgram_weights,
"skipgram_counts": repr(self.skipgram_counts)
}
def _translate(self, other: Type["Corpus"]):
if self.shift_policy != other.shift_policy:
raise TranslationError("Mismatched shifting policies")
if bool(self.space_key) != bool(other.space_key):
raise TranslationError(f"Cannot translate missing space key")
if bool(self.shift_key) != bool(other.shift_key):
raise TranslationError(f"Cannot translate missing shift key")
if self.special_replacements != other.special_replacements:
raise TranslationError("Cannot translate differing special_replacements")
if bool(self.repeat_key) != bool(other.repeat_key):
raise TranslationError("Cannot translate missing repeat key")
if self.skipgram_weights != other.skipgram_weights:
raise TranslationError("Cannot translate differing skipgram weights")
self.replacements = create_replacements(
self.space_key, self.shift_key, self.special_replacements
)
conversion: dict[str, str] = {}
conversion[other.space_key] = self.space_key
conversion[other.shift_key] = self.shift_key
conversion[other.repeat_key] = self.repeat_key
self.key_counts = Counter()
for ko, count in other.key_counts.items():
self.key_counts[conversion.get(ko, ko)] = count
self.bigram_counts = Counter()
for bo, count in other.bigram_counts.items():
self.bigram_counts[
tuple(conversion.get(ko, ko) for ko in bo)] = count
self.trigram_counts = Counter()
for to, count in other.trigram_counts.items():
self.trigram_counts[
tuple(conversion.get(ko, ko) for ko in to)] = count
self.skipgram_counts = Counter()
for so, count in other.skipgram_counts.items():
self.skipgram_counts[
tuple(conversion.get(ko, ko) for ko in so)] = count
# Registry of every corpus held in memory, translated ones included.
loaded: list["Corpus"] = []
# Registry of the corpuses that are not translations.
disk_list: list["Corpus"] = []
def get_corpus(filename: str,
        space_key: str = "space",
        shift_key: str = "shift",
        shift_policy: str = "once",
        special_replacements: dict[str, tuple[str,...]] = None,
        repeat_key: str = "",
        precision: int = 500,
        skipgram_weights: tuple[float] = None):
    """Return the `Corpus` for `filename` with the requested settings.

    Strategies are tried in this order:
      1. If no corpus for `filename` is in `loaded`, load its cache through
         `_load_corpus_list`.
      2. Return a loaded corpus whose settings match exactly, after setting
         its precision.
      3. Translate a loaded corpus of the same file. The translation is
         added to `loaded` only.
      4. Build a new `Corpus`, add it to `loaded` and `disk_list`, and save
         the cache through `_save_corpus_list`.
    """
    if special_replacements is None:
        special_replacements = {}

    if not any(c.filename == filename for c in loaded):
        _load_corpus_list(filename, precision)

    # Only corpuses of the same file can be an exact match or a translation
    # source. Translating from another file would return that file's counts
    # under this filename.
    same_file = [c for c in loaded if c.filename == filename]

    # find exact match
    for corpus_ in same_file:
        if (
            corpus_.space_key == space_key and
            corpus_.shift_key == shift_key and
            corpus_.shift_policy == shift_policy and
            corpus_.special_replacements == special_replacements and
            corpus_.repeat_key == repeat_key and
            corpus_.skipgram_weights == skipgram_weights
        ):
            corpus_.set_precision(precision)
            return corpus_

    # try translation
    for corpus_ in same_file:
        try:
            new_ = Corpus(filename, space_key, shift_key, shift_policy,
                special_replacements, precision, repeat_key, None, corpus_,
                skipgram_weights)
        except TranslationError:
            continue  # translation unsuccessful
        loaded.append(new_)
        return new_

    # create entire new one
    new_ = Corpus(filename, space_key, shift_key, shift_policy,
        special_replacements, precision, repeat_key,
        skipgram_weights=skipgram_weights)
    loaded.append(new_)
    disk_list.append(new_)
    _save_corpus_list(filename)
    return new_
def _load_corpus_list(filename: str, precision: int = 500):
try:
with open(f"corpus/{filename}.json") as file:
json_list: list[dict] = json.load(file)
except FileNotFoundError:
return
result = []
for c in json_list:
filename = c["filename"]
space_key = c.get("space_key", "")
shift_key = c.get("shift_key", "")
shift_policy = c["shift_policy"]
special_replacements = c.get("special_replacements", {})
repeat_key = c.get("repeat_key", "")
skipgram_weights = c.get("skipgram_weights", None)
result.append(Corpus(
filename, space_key, shift_key, shift_policy,
special_replacements, precision, repeat_key, json_dict=c,
skipgram_weights=skipgram_weights
))
loaded.extend(result)
disk_list.extend(result)
def _save_corpus_list(filename: str):
    """Write every `disk_list` corpus of `filename` to `corpus/<filename>.json`.

    The file holds a JSON list with one `jsonable_export()` dict per corpus.
    """
    with open(f"corpus/{filename}.json", "w") as cache_file:
        exported = []
        for corpus_ in disk_list:
            if corpus_.filename == filename:
                exported.append(corpus_.jsonable_export())
        json.dump(exported, cache_file)
if __name__ == "__main__":
    # Manual smoke test: process a sample corpus and print some statistics.
    print("Corpus test")
    sample = Corpus("tr_quotes.txt")
    print(sample.key_counts)
    twenty_most_common = sample.trigram_counts.most_common(20)
    print(twenty_most_common)
    distinct_trigrams = len(sample.trigram_counts)
    print(distinct_trigrams)