-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodeltest.py
61 lines (45 loc) · 1.46 KB
/
modeltest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from net import Net
import sys
import pickle
import numpy as np
import ast
def handleNone(string):
    """Translate the textual marker ``'None'`` into a real ``None``.

    Command-line arguments always arrive as strings, so a caller that wants
    to say "no value" has to pass the literal text ``None``. Surrounding
    whitespace is ignored when recognising the marker; the comparison is
    case-sensitive.

    Args:
        string: The raw argument text. An actual ``None`` is also accepted
            and passed through.

    Returns:
        ``None`` if ``string`` is ``None`` or is ``'None'`` once stripped of
        surrounding whitespace; otherwise ``string`` itself, unmodified
        (it is *not* stripped).
    """
    if string is None or string.strip() == 'None':
        return None
    return string
def main():
    """Run the trained network on one query and pickle its output distributions.

    Command-line arguments (read from ``sys.argv``):
        1. query       -- text handed to ``model.query_to_tensor``.
        2. parent_rule -- text handed to ``model.rule_to_tensor``; the
                          literal text ``None`` means "no parent rule".
        3. sister_rule -- same convention as ``parent_rule``.
    Any further arguments are accepted and ignored.

    Side effects:
        Reads the checkpoint ``model_testing-reasoning-17.tch`` and writes
        the tuple ``(rule_dist, term_dist)`` of NumPy arrays to
        ``model_output.pkl`` (pickle protocol 2). Both paths are relative to
        the current working directory.

    Raises:
        SystemExit: if fewer than three arguments are supplied.
    """
    if len(sys.argv) < 4:
        sys.exit('usage: %s QUERY PARENT_RULE SISTER_RULE' % sys.argv[0])

    query = sys.argv[1]
    parent_rule = handleNone(sys.argv[2])
    sister_rule = handleNone(sys.argv[3])

    # The checkpoint bundles the trained weights together with the config
    # needed to rebuild the network. The outputs below go through .numpy(),
    # which only works on CPU tensors, so force CPU placement here; this also
    # lets a checkpoint saved on a GPU machine load on a CPU-only one.
    # torch.load unpickles the file: only load checkpoints you trust.
    state = torch.load('model_testing-reasoning-17.tch', map_location='cpu')
    model = Net(state['config'], load_embeddings=False)
    model.load_state_dict(state['state_dict'])
    model.eval()

    query = model.query_to_tensor(query)
    if parent_rule is not None:
        parent_rule = model.rule_to_tensor(parent_rule)
    if sister_rule is not None:
        sister_rule = model.rule_to_tensor(sister_rule)
    rule = (parent_rule, sister_rule)

    # No variable bindings are supplied from the command line; the model is
    # always given the encoding of "no variables".
    var = model.vars_to_tensor(None)

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        rule_dist, term_dist, _var_dist = model(query, rule, var)

    # Keep only the [0][0] entry of each output.
    # NOTE(review): presumably these are log-probabilities (earlier debug
    # output applied np.exp to them) -- confirm against Net.
    rule_dist = rule_dist.detach().numpy()[0][0]
    term_dist = term_dist.detach().numpy()[0][0]
    results = (rule_dist, term_dist)

    # Protocol 2 is kept so the file stays readable by older unpicklers.
    with open('model_output.pkl', 'wb') as f:
        pickle.dump(results, f, protocol=2)
# Run the inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()