-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathchatlog.py
52 lines (40 loc) · 1.67 KB
/
chatlog.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# preprocessed data
from datasets.twitter import data
import data_utils
# Load the preprocessed Twitter corpus: vocabulary metadata plus the question
# (idx_q) and answer (idx_a) arrays, then split them into train/test/valid sets.
metadata, idx_q, idx_a = data.load_data(PATH='datasets/twitter/')
(trainX, trainY), (testX, testY), (validX, validY) = data_utils.split_dataset(idx_q, idx_a)

# Input/output sequence lengths are taken from the last axis of the arrays.
xseq_len, yseq_len = trainX.shape[-1], trainY.shape[-1]

# Questions and answers are decoded with the same idx2w table, so the
# encoder and decoder vocabularies have the same size.
xvocab_size = yvocab_size = len(metadata['idx2w'])

# Hyper-parameters.
batch_size = 16
emb_dim = 1024
import seq2seq_wrapper
import importlib
# Re-execute the wrapper module; presumably kept so that edits to
# seq2seq_wrapper are picked up when re-running this in an interactive session.
importlib.reload(seq2seq_wrapper)

# Build a 3-layer seq2seq model whose checkpoints live under ckpt/twitter/.
model = seq2seq_wrapper.Seq2Seq(
    xseq_len=xseq_len,
    yseq_len=yseq_len,
    xvocab_size=xvocab_size,
    yvocab_size=yvocab_size,
    emb_dim=emb_dim,
    num_layers=3,
    ckpt_path='ckpt/twitter/',
)
# Random-batch iterators for each split. The third argument is the batch size
# (the training split receives batch_size); validation and test use 256.
val_batch_gen = data_utils.rand_batch_gen(validX, validY, 256)
test_batch_gen = data_utils.rand_batch_gen(testX, testY, 256)
train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)

# Restore the most recent checkpoint and run inference on one test batch.
sess = model.restore_last_session()
# Use the next() builtin rather than calling the __next__ dunder directly.
# NOTE(review): element [0] is presumably the input half of an (x, y) batch — confirm.
input_ = next(test_batch_gen)[0]
output = model.predict(sess, input_)
print(output.shape)
# Decode each (question, predicted answer) pair of the test batch. Keep only
# answers with no 'unk' token that have not been seen yet, then print them and
# append them to testtraining.txt.
# NOTE(review): input_ is iterated through .T, so it presumably stores one
# sequence per column (time-major) — confirm against data_utils.
replies = []
# Use a context manager so the file is flushed and closed even on error.
with open("testtraining.txt", 'a', encoding='utf-8') as output_file:
    for ii, oi in zip(input_.T, output):
        q = data_utils.decode(sequence=ii, lookup=metadata['idx2w'], separator=' ')
        decoded = data_utils.decode(sequence=oi, lookup=metadata['idx2w'], separator=' ').split(' ')
        # Skip answers containing the unknown-word marker.
        if 'unk' in decoded:
            continue
        # Report each distinct answer only once.
        if decoded in replies:
            continue
        line = 'q : [{0}]; a : [{1}]'.format(q, ' '.join(decoded))
        print(line)
        output_file.write(line + "\n")
        replies.append(decoded)