-
-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathprepro.py
executable file
·55 lines (47 loc) · 2.05 KB
/
prepro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
'''
Preprocessing.
Note:
Nine-key pinyin keyboard layout sample:
` ABC DEF
GHI JKL MNO
PQRS TUV WXYZ
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
import codecs
#import pickle
import json
def build_vocab(isqwerty=None, corpus_path='data/zh.tsv', vocab_path=None):
    """Builds the pinyin and hanzi vocabularies from the corpus and saves them as JSON.

    The output file holds a single JSON array of four objects, in this order:
    ``pnyn2idx``, ``idx2pnyn``, ``hanzi2idx``, ``idx2hanzi``. JSON object keys
    are always strings, so the integer keys of the two ``idx2*`` maps are
    written as strings (e.g. ``"3"``) and come back as strings when loaded.

    Args:
        isqwerty: If true, every pinyin symbol gets its own index (qwerty
            layout). If false, symbols that share a key on the nine-key
            layout share one index. Defaults to ``hp.isqwerty``.
        corpus_path: Path of the UTF-8, tab-separated corpus. The third
            column of every line is read as the hanzi sentence.
        vocab_path: Path of the JSON file to write. Defaults to
            ``data/vocab.qwerty.json`` or ``data/vocab.nine.json``, depending
            on the layout.

    Raises:
        IndexError: If a corpus line has fewer than three tab-separated columns.
    """
    from collections import Counter
    from itertools import chain

    if isqwerty is None:
        isqwerty = hp.isqwerty
    if vocab_path is None:
        vocab_path = 'data/vocab.qwerty.json' if isqwerty else 'data/vocab.nine.json'

    # pinyin
    if isqwerty:
        pnyns = "EUabcdefghijklmnopqrstuvwxyz0123456789。,!?"  # E: Empty, U: Unknown
        pnyn2idx = {pnyn: idx for idx, pnyn in enumerate(pnyns)}
        idx2pnyn = {idx: pnyn for idx, pnyn in enumerate(pnyns)}
    else:
        pnyn2idx, idx2pnyn = dict(), dict()
        pnyns_list = ["E", "U", "abc", "def", "ghi", "jkl", "mno", "pqrs", "tuv", "wxyz",
                      "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
                      u"。", u",", u"!", u"?"]  # E: Empty, U: Unknown
        for idx, keys in enumerate(pnyns_list):
            # Several letters share one key, so the inverse map stores the
            # whole key group (it used to be saved empty for this layout).
            idx2pnyn[idx] = keys
            for pnyn in keys:
                pnyn2idx[pnyn] = idx

    # hanzis
    # Close the corpus file deterministically instead of leaking the handle.
    with codecs.open(corpus_path, 'r', 'utf-8') as fin:
        lines = fin.read().splitlines()
    hanzi_sents = [line.split('\t')[2] for line in lines]
    hanzi2cnt = Counter(chain.from_iterable(hanzi_sents))

    # Remove long-tail characters. "_" is re-added below as a reserved symbol,
    # so it is filtered out here; unlike list.remove("_") this does not raise
    # when "_" is missing from the corpus or too rare to pass the threshold.
    # NOTE(review): a corpus that itself contains "E" or "U" more than 5 times
    # still yields duplicate entries for them - confirm whether that can occur.
    hanzis = [hanzi for hanzi, cnt in hanzi2cnt.items() if cnt > 5 and hanzi != "_"]
    hanzis = ["E", "U", "_"] + hanzis  # 0: empty, 1: unknown, 2: blank
    hanzi2idx = {hanzi: idx for idx, hanzi in enumerate(hanzis)}
    idx2hanzi = {idx: hanzi for idx, hanzi in enumerate(hanzis)}

    # The context manager guarantees the JSON is flushed and the file closed.
    with open(vocab_path, 'w') as fout:
        json.dump((pnyn2idx, idx2pnyn, hanzi2idx, idx2hanzi), fout)
if __name__ == '__main__':
    # Script entry point: build the vocabulary file, then report completion.
    build_vocab()
    print("Done")