-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeploy.py
128 lines (102 loc) · 4.91 KB
/
deploy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import pyximport; pyximport.install(pyimport=True)
import argparse
import warnings
from data.local_news_data.baltimore_sun.loader import baltimore_sun_covid_data
from data.local_news_data.cbs.loader import cbs_covid_data
from data.local_news_data.wbaltv.loader import wbaltv_covid_data
from data.template import Dataset, Document
from preprocess.processor import TextProcessor
from search_engine import SearchEngine
from vectorize.gensim import GensimVectorizer
from vectorize.one_hot import OneHotVectorizer
warnings.filterwarnings("ignore")
class CovidDataset(Dataset):
    """Combined corpus of COVID-19 articles from several local news sources.

    Merges the documents of every dataset passed in, giving each merged
    document a fresh, consecutive doc id starting at 1.
    """

    def __init__(self, datasets):
        """Build the merged corpus from an iterable of already-loaded datasets."""
        super().__init__()
        self.documents = []  # merged, re-numbered Document objects
        self.load_docs(datasets)

    def load_docs(self, datasets):
        """Flatten all source datasets into ``self.documents`` with new ids.

        ``enumerate`` over a chained generator replaces the hand-rolled
        ``doc_id`` counter and avoids rebinding the loop variable.
        """
        all_docs = (doc for dataset in datasets for doc in dataset.documents)
        for doc_id, doc in enumerate(all_docs, start=1):
            self.documents.append(
                Document(doc_id, doc.title, doc.content, doc.url))

    def load_queries(self, filename):
        """No-op: queries come interactively from stdin at deploy time."""
        pass

    def load_relevant_docs(self, filename):
        """No-op: relevance judgments are not used at deploy time."""
        pass
def arguments_parser(argv=None):
    """Parse command-line options for the search-engine deployment.

    Parameters
    ----------
    argv : list of str, optional
        Argument list to parse. Defaults to ``None``, which makes argparse
        read ``sys.argv[1:]`` — so existing no-argument call sites behave
        exactly as before.

    Returns
    -------
    argparse.Namespace
        Attributes: ``personalize`` (bool), ``embedding`` (str),
        ``weighting_scheme`` (str), ``top_k`` (int), ``expand_query`` (bool).
    """
    parser = argparse.ArgumentParser()
    # action='store_true' already defaults to False; the old explicit
    # set_defaults(...) calls were redundant.
    parser.add_argument("--personalize", dest="personalize", action="store_true")
    parser.add_argument("--embedding",
                        default="one-hot",
                        choices=["one-hot", "word2vec-google-news-300",
                                 "glove-twitter-100", "glove-wiki-gigaword-100",
                                 "glove-wiki-gigaword-200",
                                 "fasttext-wiki-news-subwords-300"])
    parser.add_argument("--weighting_scheme",
                        default="tf-idf",
                        choices=["mean", "tf-idf", "sif", "usif"])
    # Parse as int up front instead of default="10"; callers that still do
    # int(args.top_k) keep working (int(10) == 10).
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--expand_query", dest="expand_query", action="store_true")
    return parser.parse_args(argv)
def main():
    """Deploy the COVID local-news search engine and answer stdin queries.

    Builds the merged dataset, configures preprocessing and vectorization
    from the command-line arguments, then loops reading queries until the
    user types ``exit``.
    """
    args = arguments_parser()
    # Merge the three local-news corpora into one dataset with fresh doc ids.
    datasets = [cbs_covid_data, wbaltv_covid_data, baltimore_sun_covid_data]
    covid_data = CovidDataset(datasets)
    # Best (embedding, weighting_scheme) configs found during evaluation:
    #   1. ("one-hot", "tf-idf")                 2. ("one-hot", "mean")
    #   3. ("word2vec-google-news-300", "usif")  4. ("word2vec-google-news-300", "sif")
    #   5. ("word2vec-google-news-300", "mean")
    print("####################################################")
    print("Model details (embedding, weighting_scheme): ({}, {})".format(
        args.embedding, args.weighting_scheme))
    text_preprocessor = TextProcessor(re_tokenize=True,
                                      remove_stopwords=True,
                                      stemming=True,
                                      expand_contractions=True,
                                      replace_acronyms=True,
                                      substitute_emoticons=True)
    # Choose the vectorizer per embedding type; the engine itself is
    # constructed once below instead of duplicating the call in each branch.
    if args.embedding == "one-hot":
        # Sparse one-hot/tf-idf matching benefits from stemming.
        text_preprocessor.is_stemming = True
        vectorizer = OneHotVectorizer(weighting=args.weighting_scheme,
                                      is_expand_query=args.expand_query)
    else:
        # Pretrained embedding vocabularies expect unstemmed surface forms.
        text_preprocessor.is_stemming = False
        vectorizer = GensimVectorizer(model_name=args.embedding,
                                      weighting=args.weighting_scheme,
                                      is_expand_query=args.expand_query)
    search_engine = SearchEngine(dataset=covid_data,
                                 text_preprocessor=text_preprocessor,
                                 vectorizer=vectorizer,
                                 similarity_metric="cosine")
    # for query processing (mispelled query text)
    # Note: due to long search engine deploy time reasons,
    # we are not performing spell check on the documents
    search_engine.text_preprocessor.is_spelling_autocorrect = True
    print("Search engine initialized! Try the search engine:\n")
    top_k = int(args.top_k)  # loop-invariant, hoisted out of the query loop
    # Single prompt site (the original duplicated the input/print pair
    # before and inside the loop); 'exit' ends the session.
    while True:
        query = input("Query: ")
        print()
        if query == "exit":
            break
        matching_docs = search_engine.search(query,
                                             personalize=args.personalize,
                                             top_k=top_k)[0]
        for rank, doc in enumerate(matching_docs, start=1):
            print(str(rank) + ". " + doc.title.raw)
            print("URL: " + doc.url)
            print()
        print("####################################################\n")
# Run the interactive deployment only when executed as a script, not on import.
if __name__ == '__main__':
    main()