predict.py
from sacred import Experiment
from trankit import Pipeline, verify_customized_pipeline
from pathlib import Path
from typing import Dict
from collections import Counter
import re

ex = Experiment()


@ex.config
def config():
    category = 'customized-ner'  # noqa
    lang = None  # noqa
    save_dir = './save_dir'  # noqa
    raw_data_dir = None  # noqa
    output_data_dir = None  # noqa


def symlink_pipeline_part(lang, save_dir, category, part):
    # Map ISO language codes to the directory names used under ./langs.
    mapping = {
        'bg': 'bulgarian',
        'cs': 'czech',
        'pl': 'polish',
        'ru': 'russian',
        'sl': 'slovenian',
        'uk': 'ukrainian'
    }
    lang_prefix = str(Path('./langs') / mapping[lang] / mapping[lang])
    to_path = str(Path(save_dir) / category / category)
    p_to = Path(lang_prefix + part).resolve()
    p_from = Path(to_path + part)
    # Link the language-specific model file into the place the customized
    # pipeline expects, replacing any existing symlink.
    try:
        p_from.symlink_to(p_to)
    except FileExistsError:
        p_from.unlink()
        p_from.symlink_to(p_to)
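# A sketch of what the function above produces, assuming hypothetical values
# lang='pl', save_dir='./save_dir', category='customized-ner' and
# part='.tagger.mdl': it creates the symlink
#   ./save_dir/customized-ner/customized-ner.tagger.mdl
# pointing at the resolved path of ./langs/polish/polish.tagger.mdl.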


def predict_on_text(pipeline: Pipeline, text: str) -> Dict:
    predictions = []
    r = pipeline(text)
    current_mwe = dict()
    for sentence in r['sentences']:
        for token in sentence['tokens']:
            token_ner = token['ner']
            if token_ner.startswith('S-'):
                # Single-token entity: emit it directly.
                predictions.append([
                    token['text'],
                    token['lemma'],
                    token_ner.replace('S-', '')
                ])
            elif token_ner.startswith('B-'):
                # Beginning of a multi-word entity: start (or extend) the buffer.
                if 'text' in current_mwe:
                    current_mwe['text'] += ' ' + token['text']
                    current_mwe['lemma'] += ' ' + token['lemma']
                else:
                    current_mwe['text'] = token['text']
                    current_mwe['lemma'] = token['lemma']
            elif token_ner.startswith('I-'):
                # Inside a multi-word entity: keep extending the buffer.
                if 'text' in current_mwe:
                    current_mwe['text'] += ' ' + token['text']
                    current_mwe['lemma'] += ' ' + token['lemma']
                else:
                    current_mwe['text'] = token['text']
                    current_mwe['lemma'] = token['lemma']
            elif token_ner.startswith('E-'):
                # End of a multi-word entity: extend, emit and reset the buffer.
                if 'text' in current_mwe:
                    current_mwe['text'] += ' ' + token['text']
                    current_mwe['lemma'] += ' ' + token['lemma']
                else:
                    current_mwe['text'] = token['text']
                    current_mwe['lemma'] = token['lemma']
                predictions.append([
                    current_mwe['text'],
                    current_mwe['lemma'],
                    token_ner.replace('E-', '')
                ])
                current_mwe = {}
    # Group predictions by surface form: keep the first lemma seen and
    # collect every predicted tag for that form.
    predictions_by_token = {}
    for prediction in predictions:
        text, lemma, tag = prediction
        if text not in predictions_by_token:
            predictions_by_token[text] = {
                'lemma': lemma,
                'tag': [tag]
            }
        else:
            predictions_by_token[text]['tag'].append(tag)
    return predictions_by_token
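# A minimal sketch of the structure predict_on_text() returns, assuming a
# hypothetical document in which 'Warszawa' is predicted twice as LOC and
# once as ORG:
#   {'Warszawa': {'lemma': 'Warszawa', 'tag': ['LOC', 'LOC', 'ORG']}}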


def generate_output_text(predictions: Dict) -> str:
    lines = []
    for token in sorted(predictions.keys(),
                        key=lambda x: x.lower()):
        prediction = predictions[token]
        tags = prediction['tag']
        lemma = prediction['lemma']
        counts = Counter(tags)
        # Get the first most common tag
        tag = counts.most_common(1)[0][0]
        # If we matched a token with both ORG and PER, return PER
        # http://bsnlp.cs.helsinki.fi/System_response_guidelines-1.2.pdf
        # (page 3)
        if 'ORG' in counts and 'PER' in counts:
            tag = 'PER'
        # If we matched a token with both ORG and PRO, return PRO
        # http://bsnlp.cs.helsinki.fi/System_response_guidelines-1.2.pdf
        # (page 3)
        elif 'ORG' in counts and 'PRO' in counts:
            tag = 'ORG'
        # If there is a dot as part of the token, ensure there is no whitespace
        # around it (like in 'W . Brytania' vs 'W.Brytania')
        if '.' in token:
            token = re.sub(r'\s+\.\s+', '.', token)
            lemma = re.sub(r'\s+\.\s+', '.', lemma)
        lines.append(f'{token}\t{lemma}\t{tag}\tORG-RAND')
    return '\n'.join(lines)
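# Each output line is tab-separated, with 'ORG-RAND' used as a fixed
# placeholder cluster id, e.g. (hypothetical values):
#   'Warszawa\tWarszawa\tLOC\tORG-RAND'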


@ex.automain
def main(category, lang, save_dir, raw_data_dir, output_data_dir):
    # Expose the selected language's model files under the customized
    # pipeline's save directory, then verify and load that pipeline.
    symlink_pipeline_part(lang, save_dir, category, '.tagger.mdl')
    symlink_pipeline_part(lang, save_dir, category, '.vocabs.json')
    symlink_pipeline_part(lang, save_dir, category, '_lemmatizer.pt')
    symlink_pipeline_part(lang, save_dir, category, '.tokenizer.mdl')
    verify_customized_pipeline(
        category=category,
        save_dir=save_dir
    )
    p = Pipeline(lang=category,
                 cache_dir=save_dir)
    for file in Path(raw_data_dir).rglob('*'):
        print(file, file.name)
        file_text = file.read_text()
        lines = file_text.strip().split('\n')
        # The first line of each raw file holds the document id; the text
        # body starts after the metadata header lines.
        file_id = lines[0]
        text = ' '.join(lines[4:])
        predictions = predict_on_text(p, text)
        output_data_dir = Path(output_data_dir)
        # Ensure the output directories exist
        output_data_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_data_dir / file.name
        with output_path.open('w') as f:
            f.write(file_id + '\n')
            f.write(generate_output_text(predictions))
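# A sketch of how this Sacred experiment might be invoked from the command
# line (hypothetical paths; Sacred's 'with' syntax overrides the config
# defaults, and lang must be one of the codes listed in symlink_pipeline_part):
#   python predict.py with lang=pl raw_data_dir=./data/raw/pl output_data_dir=./out/pl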