-
Notifications
You must be signed in to change notification settings - Fork 126
/
Copy pathmain.py
49 lines (42 loc) · 1.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
######################################################################
######################################################################
# Copyright Tsung-Hsien Wen, Cambridge Dialogue Systems Group, 2016 #
######################################################################
######################################################################
import sys
import os
import numpy as np
from utils.commandparser import RNNLGOptParser
from generator.net import Model
from generator.ngram import Ngram
from generator.knn import KNN
import warnings
# Ignore DeprecationWarning messages for the whole run.
warnings.simplefilter("ignore", DeprecationWarning)

# Modes handled by the neural-network generator (generator.net.Model).
_NN_MODES = ('train', 'adapt', 'test')


def main(args=None):
    """Run the generator selected by the command-line ``mode`` option.

    Dispatch:
      * ``knn``             -> ``KNN(config, args).testKNN()``
      * ``ngram``           -> ``Ngram(config, args).testNgram()``
      * ``train`` / ``adapt`` -> ``Model(config, args).trainNet()``
      * ``test``            -> ``Model(config, args).testNet()``

    Args:
        args: parsed options exposing ``mode`` and ``config`` attributes.
            When omitted (the normal command-line case) they are obtained
            from ``RNNLGOptParser()``.

    Raises:
        SystemExit: for any other mode. The check happens before a Model is
            constructed, so an unsupported mode fails fast with a non-zero
            exit status instead of building a model and doing nothing.
    """
    if args is None:
        args = RNNLGOptParser()
    config = args.config
    mode = args.mode

    if mode == 'knn':
        # knn case
        KNN(config, args).testKNN()
    elif mode == 'ngram':
        # ngram case
        Ngram(config, args).testNgram()
    elif mode in _NN_MODES:
        # NN case: both training from scratch and adaptation use trainNet.
        model = Model(config, args)
        if mode == 'test':
            model.testNet()
        else:
            model.trainNet()
    else:
        # TODO: a 'realtime' mode (read a target dialogue act from stdin and
        # print the sentences returned by Model.genSent) is not supported yet.
        sys.exit('Unsupported mode: %r' % (mode,))


if __name__ == '__main__':
    main()