ck12_wiki_predict.py
import argparse

import numpy as np
import pandas as pd

import utils

# CK-12 topic pages used to collect subject keywords
ck12_url_topic = ['https://www.ck12.org/earth-science/', 'http://www.ck12.org/life-science/',
                  'http://www.ck12.org/physical-science/', 'http://www.ck12.org/biology/',
                  'http://www.ck12.org/chemistry/', 'http://www.ck12.org/physics/']

# directory where the downloaded Wikipedia articles are stored
wiki_docs_dir = 'data/wiki_data'


def get_wiki_docs():
    # collect keywords from the CK-12 topic pages
    ck12_keywords = set()
    for url_topic in ck12_url_topic:
        keywords = utils.get_keyword_from_url_topic(url_topic)
        for kw in keywords:
            ck12_keywords.add(kw)

    # download and save the corresponding Wikipedia articles
    utils.get_save_wiki_docs(ck12_keywords, wiki_docs_dir)


def predict(data, docs_per_q):
    # index the saved Wikipedia docs: per-document term frequencies and corpus IDF
    docs_tf, words_idf = utils.get_docstf_idf(wiki_docs_dir)

    res = []
    for index, row in data.iterrows():
        # tokenize the answer options
        w_A = set(utils.tokenize(row['answerA']))
        w_B = set(utils.tokenize(row['answerB']))
        w_C = set(utils.tokenize(row['answerC']))
        w_D = set(utils.tokenize(row['answerD']))

        sc_A = 0
        sc_B = 0
        sc_C = 0
        sc_D = 0

        q = row['question']

        # get_docs_importance_for_question returns (doc, importance) pairs for the top
        # docs_per_q documents; list(zip(*...))[0] keeps only the doc names
        # (the original zip(*...)[0] is not subscriptable in Python 3)
        for d in list(zip(*utils.get_docs_importance_for_question(q, docs_tf, words_idf, docs_per_q)))[0]:
            # score each answer by the summed TF-IDF weight of its words in the retrieved doc
            for w in w_A:
                if w in docs_tf[d]:
                    sc_A += 1. * docs_tf[d][w] * words_idf[w]
            for w in w_B:
                if w in docs_tf[d]:
                    sc_B += 1. * docs_tf[d][w] * words_idf[w]
            for w in w_C:
                if w in docs_tf[d]:
                    sc_C += 1. * docs_tf[d][w] * words_idf[w]
            for w in w_D:
                if w in docs_tf[d]:
                    sc_D += 1. * docs_tf[d][w] * words_idf[w]

        res.append(['A', 'B', 'C', 'D'][np.argmax([sc_A, sc_B, sc_C, sc_D])])

    return res


if __name__ == '__main__':
    # parse input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--fname', type=str, default='validation_set.tsv', help='file name with data')
    parser.add_argument('--docs_per_q', type=int, default=10, help='number of top docs to use per question')
    parser.add_argument('--get_data', type=int, default=0, help='flag to download wiki data for IR')
    args = parser.parse_args()

    if args.get_data:
        get_wiki_docs()

    # read data
    data = pd.read_csv('data/' + args.fname, sep='\t')

    # predict
    res = predict(data, args.docs_per_q)

    # save result
    pd.DataFrame({'id': list(data['id']), 'correctAnswer': res})[['id', 'correctAnswer']].to_csv('prediction.csv', index=False)
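
# Usage (a sketch; assumes utils.py provides the helper functions imported above and that
# data/validation_set.tsv has the columns read in predict(): id, question, answerA..answerD):
#
#   python ck12_wiki_predict.py --get_data 1                                 # first run: download the Wikipedia docs
#   python ck12_wiki_predict.py --fname validation_set.tsv --docs_per_q 10   # score answers and write predictions
#
# The second command writes prediction.csv with columns id and correctAnswer.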