# demo.py -- interactive demo: reads document text from the prompt and prints the predicted clusters.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import input
import tensorflow as tf
import coref_model as cm
import util
import nltk
nltk.download("punkt")
from nltk.tokenize import sent_tokenize, word_tokenize
def create_example(text):
    """Turn raw document text into an example dict for the model.

    The text is split into sentences with ``sent_tokenize`` and every
    sentence into tokens with ``word_tokenize``. Each token is paired with
    an empty-string speaker, and no clusters are attached.

    Args:
        text: Raw document text.

    Returns:
        A dict with the keys "doc_key" (always "nw"), "clusters" (an empty
        list), "sentences" (one token list per sentence) and "speakers"
        (same shape as "sentences", filled with empty strings).
    """
    tokenized = []
    blank_speakers = []
    for raw_sentence in sent_tokenize(text):
        tokens = word_tokenize(raw_sentence)
        tokenized.append(tokens)
        # One empty speaker entry per token keeps both lists aligned.
        blank_speakers.append(["" for _ in tokens])

    example = {
        "doc_key": "nw",
        "clusters": [],
        "sentences": tokenized,
        "speakers": blank_speakers,
    }
    return example
def print_predictions(example):
    """Print every predicted cluster as a list of its mention strings.

    Args:
        example: Dict with "sentences" (passed through ``util.flatten`` to
            obtain the word sequence) and "predicted_clusters" (an iterable
            of clusters, each an iterable of mentions whose items 0 and 1
            are the start and end word indices).
    """
    tokens = util.flatten(example["sentences"])
    for cluster in example["predicted_clusters"]:
        mention_texts = []
        for mention in cluster:
            start = mention[0]
            end = mention[1]
            # The end index is inclusive, so the slice bound is end + 1.
            mention_texts.append(" ".join(tokens[start:end + 1]))
        print(u"Predicted cluster: {}".format(mention_texts))
def make_predictions(text, model, session=None):
    """Run the model on raw text and return the example with predictions.

    Args:
        text: Raw document text; converted with ``create_example``.
        model: Object providing ``tensorize_example``, ``input_tensors``,
            ``predictions``, ``head_scores``, ``get_predicted_antecedents``
            and ``get_predicted_clusters``.
        session: Optional TensorFlow session used for ``session.run``. When
            omitted, the module-level ``session`` global (bound by the
            script entry point) is used if present, otherwise TensorFlow's
            default session.

    Returns:
        The example dict extended with:
          * "predicted_clusters": first value returned by
            ``model.get_predicted_clusters``.
          * "top_spans": list of ``(start, end)`` tuples of Python ints.
          * "head_scores": ``head_scores.tolist()``.

    Raises:
        RuntimeError: If no session was passed and none can be found.
    """
    if session is None:
        # Legacy path: the script entry point binds a module-level `session`.
        session = globals().get("session")
    if session is None:
        session = tf.get_default_session()
    if session is None:
        raise RuntimeError(
            "make_predictions needs a TensorFlow session: pass session=... "
            "or call it inside a `with tf.Session():` block."
        )

    example = create_example(text)
    tensorized_example = model.tensorize_example(example, is_training=False)
    feed_dict = {i: t for i, t in zip(model.input_tensors, tensorized_example)}

    outputs = session.run(model.predictions + [model.head_scores], feed_dict=feed_dict)
    # The first three outputs are not needed here.
    _, _, _, mention_starts, mention_ends, antecedents, antecedent_scores, head_scores = outputs

    predicted_antecedents = model.get_predicted_antecedents(antecedents, antecedent_scores)
    example["predicted_clusters"], _ = model.get_predicted_clusters(
        mention_starts, mention_ends, predicted_antecedents
    )
    # Materialize the pairs: on Python 3 a bare zip() is a one-shot iterator
    # that is empty after the first pass and cannot be serialized.
    example["top_spans"] = [
        (int(start), int(end)) for start, end in zip(mention_starts, mention_ends)
    ]
    example["head_scores"] = head_scores.tolist()
    return example
if __name__ == "__main__":
    # Interactive loop: build and restore the model once, then repeatedly
    # prompt for document text and print the predicted clusters.
    config = util.initialize_from_env()
    model = cm.CorefModel(config)
    # `session` is intentionally a module-level name; make_predictions reads it.
    with tf.Session() as session:
        model.restore(session)
        while True:
            try:
                text = input("Document text: ")
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / closed stdin or Ctrl-C: leave the loop without a
                # traceback so the session context manager can close cleanly.
                print()
                break
            # Empty input just re-prompts.
            if text:
                print_predictions(make_predictions(text, model))