Commit

update eval.py
lilianweng committed Nov 6, 2018
1 parent 56d399e commit b10e043
Showing 3 changed files with 21 additions and 8 deletions.
4 changes: 4 additions & 0 deletions data.py
@@ -225,6 +225,10 @@ def data_generator(self, batch_size, seq_len, data_type='train', file_prefix=Non

             ep += 1

+        # leftover
+        if len(batch_src) > 0:
+            yield np.array(batch_src).copy(), np.array(batch_tgt).copy()
+

 def recover_sentence(sent_ids, id2word):
     """Convert a list of word ids back to a sentence string.
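The four added lines in data_generator emit the final partial batch instead of dropping it when the epoch loop ends, which matters at evaluation time when every test sentence should be scored. Below is a minimal, self-contained sketch of the same leftover-batch pattern; the function name and the sources/targets arguments are illustrative and not taken from the repository.

import numpy as np

def batched_pairs(sources, targets, batch_size):
    """Yield (source, target) batches as arrays; the final batch may be smaller."""
    batch_src, batch_tgt = [], []
    for src, tgt in zip(sources, targets):
        batch_src.append(src)
        batch_tgt.append(tgt)
        if len(batch_src) == batch_size:
            yield np.array(batch_src), np.array(batch_tgt)
            batch_src, batch_tgt = [], []
    # leftover: emit whatever is still buffered so no example is skipped
    if len(batch_src) > 0:
        yield np.array(batch_src).copy(), np.array(batch_tgt).copy()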
24 changes: 17 additions & 7 deletions eval.py
@@ -7,7 +7,7 @@
 """
 import click
 import numpy as np
-from data import DatasetManager, recover_sentence
+from data import DatasetManager, recover_sentence, PAD_ID
 from transformer import Transformer
 from nltk.translate.bleu_score import corpus_bleu

@@ -21,20 +21,30 @@ def eval(model_name, file_prefix):

     cfg = transformer.config

+    batch_size = cfg['train_params']['batch_size']
+    seq_len = cfg['train_params']['seq_len'] + 1
+
     dm = DatasetManager(cfg['dataset'])
     dm.maybe_download_data_files()
     data_iter = dm.data_generator(
-        cfg['train_params']['batch_size'],
-        cfg['train_params']['seq_len'] + 1,
-        data_type='test', file_prefix=file_prefix, epoch=1,
-    )
+        batch_size, seq_len, data_type='test', file_prefix=file_prefix, epoch=1)

     refs = []
     hypos = []
     for source_ids, target_ids in data_iter:
+        valid_size = len(source_ids)
+        print(source_ids.shape, target_ids.shape)
+
+        if valid_size < batch_size:
+            source_ids = np.array(list(source_ids) + [[PAD_ID] * seq_len] * (batch_size - source_ids.shape[0]))
+            target_ids = np.array(list(target_ids) + [[PAD_ID] * seq_len] * (batch_size - target_ids.shape[0]))
+
         pred_ids = transformer.predict(source_ids)
-        refs += [[recover_sentence(sent_ids, dm.target_id2word)] for sent_ids in target_ids]
-        hypos += [recover_sentence(sent_ids, dm.target_id2word) for sent_ids in pred_ids]
+
+        refs += [[recover_sentence(sent_ids, dm.target_id2word)]
+                 for sent_ids in target_ids[:valid_size]]
+        hypos += [recover_sentence(sent_ids, dm.target_id2word)
+                  for sent_ids in pred_ids[:valid_size]]
         print(f"Num. sentences processed: {len(hypos)}", end='\r', flush=True)

     print()
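The eval.py changes above complement the new leftover batch: since the restored Transformer graph presumably expects inputs of exactly batch_size rows, a smaller final batch is padded up to batch_size with PAD_ID rows before transformer.predict() is called, and the [:valid_size] slice then discards the predictions that correspond to padding. A minimal sketch of this pad-then-truncate pattern, using an illustrative predict_fn and a placeholder PAD_ID value rather than the repository's own:

import numpy as np

PAD_ID = 0  # placeholder; eval.py imports the real PAD_ID from data.py

def predict_full_batch(predict_fn, source_ids, batch_size, seq_len):
    """Pad a short batch to the fixed batch size, run the model, drop the padding."""
    valid_size = len(source_ids)
    if valid_size < batch_size:
        pad_rows = [[PAD_ID] * seq_len] * (batch_size - valid_size)
        source_ids = np.array(list(source_ids) + pad_rows)
    pred_ids = predict_fn(source_ids)
    return pred_ids[:valid_size]  # keep only the rows for real examples

The same slicing is applied to target_ids in the diff, so the references collected for corpus_bleu line up one-to-one with the hypotheses.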
1 change: 0 additions & 1 deletion utils.py
@@ -5,7 +5,6 @@

 REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__)))

-
 class BaseModelMixin:
     """Abstract object representing an Reader model.
     Code borrowed from: https://github.com/devsisters/DQN-tensorflow/blob/master/dqn/base.py
