Commit

fix model loading bug + add unittest
lilianweng committed Nov 7, 2018
1 parent b10e043 commit ae29030
Showing 5 changed files with 39 additions and 28 deletions.
eval.py — 1 change: 0 additions & 1 deletion

@@ -33,7 +33,6 @@ def eval(model_name, file_prefix):
     hypos = []
     for source_ids, target_ids in data_iter:
         valid_size = len(source_ids)
-        print(source_ids.shape, target_ids.shape)
 
         if valid_size < batch_size:
             source_ids = np.array(list(source_ids) + [[PAD_ID] * seq_len] * (batch_size - source_ids.shape[0]))
train.py — 3 changes: 2 additions & 1 deletion

@@ -11,6 +11,7 @@
 from baselines import logger
 from data import *
 from transformer import *
+from utils import print_trainable_variables
 
 
 @click.command()

@@ -49,7 +50,7 @@ def train(seq_len, d_model, d_ff, n_head, batch_size, max_steps, dataset):
         tf_sess_config=tf_sess_config
     )
     transformer.build_model(dataset, dm.source_id2word, dm.target_id2word, PAD_ID, **train_params)
-    transformer.print_trainable_variables()
+    print_trainable_variables()
 
     train_data_iter = dm.data_generator(batch_size, seq_len + 1, data_type='train')
     test_data_iter = dm.data_generator(batch_size, seq_len + 1, data_type='test')
transformer.py — 11 changes: 3 additions & 8 deletions

@@ -19,7 +19,6 @@
 
 from utils import BaseModelMixin, REPO_ROOT
 from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
-from baselines.common.tf_util import display_var_info
 from data import recover_sentence, START_ID, PAD_ID
 
 

@@ -208,16 +207,16 @@ def load_model(cls, model_name, is_training=False):
         model.build_model(cfg['dataset'], cfg['input_id2word'], cfg['target_id2word'],
                           pad_id=cfg['pad_id'], is_training=is_training,
                           **cfg['train_params'])
-        model.sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
-
+        # model.sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
         model.load_checkpoint()
         return model
 
     def embedding(self, inp, vocab_size, zero_pad=True):
         """When the `zero_pad` flag is on, the first row in the embedding lookup table is
         fixed to be an all-zero vector, corresponding to the '<pad>' symbol."""
         embed_size = self.d_model
-        embed_lookup = tf.get_variable("embed_lookup", [vocab_size, embed_size], tf.float32)
+        embed_lookup = tf.get_variable("embed_lookup", [vocab_size, embed_size], tf.float32,
+                                       initializer=tf.contrib.layers.xavier_initializer())
 
         if zero_pad:
             assert self._pad_id == 0
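A note on `zero_pad`: the hunk above shows the new Xavier initializer and the pad-id assert, but not the masking itself. Below is a minimal sketch of the mechanism the docstring describes, using the common TF 1.x idiom; it is illustrative, not necessarily the repo's exact code.

    import tensorflow as tf

    vocab_size, embed_size = 1000, 64  # embed_size stands in for self.d_model

    embed_lookup = tf.get_variable(
        "embed_lookup", [vocab_size, embed_size], tf.float32,
        initializer=tf.contrib.layers.xavier_initializer())

    # Pin row 0 (the '<pad>' id, asserted above to be 0) to the zero vector,
    # so padding tokens embed to zeros and contribute nothing downstream.
    embed_lookup = tf.concat(
        [tf.zeros([1, embed_size]), embed_lookup[1:, :]], axis=0)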
Expand Down Expand Up @@ -563,10 +562,6 @@ def evaluate(self, input_ids, target_ids):

# ============================= Utils ===============================

def print_trainable_variables(self):
t_vars = tf.trainable_variables(scope=self.model_name)
display_var_info(t_vars)

def _check_variable(self, v, name):
if v is None:
raise ValueError(f"Call build_model() to initialize {name}.")
Expand Down
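Why the `load_model()` change above is the loading fix: previously the initializers ran after `build_model()`, and `load_checkpoint()` also imported the meta graph on top of the already-built graph (see utils.py below), which can leave the restored weights attached to an imported copy while the model actually used keeps fresh random values. A plausible distillation of the pattern this commit moves to — rebuild the graph in code, then restore with a plain Saver — is sketched here; the function and argument names are illustrative, not the repo's API.

    import tensorflow as tf

    def load_from_checkpoint(checkpoint_dir, build_graph_fn, sess):
        build_graph_fn()              # recreate the exact same graph structure
        saver = tf.train.Saver()      # matches checkpoint entries to existing
                                      # variables by name
        ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt_path:
            # No initializer call is needed: restore() assigns every variable,
            # and running an initializer afterwards would overwrite the values.
            saver.restore(sess, ckpt_path)
        return ckpt_path is not None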
transformer_test.py — 19 changes: 15 additions & 4 deletions

@@ -3,17 +3,19 @@
 http://lilianweng.github.io/lil-log
 Oct 2018
 """
-import os
-import numpy as np
+import shutil
+import numpy as np
 import tensorflow as tf
 
 from data import DatasetManager, PAD_ID
 from transformer import Transformer
+from utils import print_trainable_variables
 
 
 class TransformerTest(tf.test.TestCase):
     def setUp(self):
-        self.t = Transformer(model_name='test')
+        self.t = Transformer(model_name='test', num_heads=4, d_model=64, d_ff=128,
+                             num_enc_layers=2, num_dec_layers=2)
         self.batch_size = 4
         self.seq_len = 5
         self.raw_input_ph = tf.placeholder(tf.int32, shape=(self.batch_size, self.seq_len))

@@ -34,14 +36,23 @@ def test_build_and_load_model(self):
         dm.load_vocab()
 
         self.t.build_model('iwslt15', dm.source_id2word, dm.target_id2word, PAD_ID)
+        print_trainable_variables()
+        self.t.init()
+        value_dict = self.t.get_variable_values()
 
         tf.reset_default_graph()
         model = Transformer.load_model('test')
-        model.print_trainable_variables()
         out = model.predict(np.zeros(model.raw_input_ph.shape))
         assert out.shape == model.raw_target_ph.shape
 
+        value_dict2 = model.get_variable_values()
+        for k in value_dict2:
+            print("\n*************************************")
+            print(k)
+            print(value_dict[k])
+            print(value_dict2[k])
+            assert np.allclose(value_dict[k], value_dict2[k])
+
     def test_construct_padding_mask(self):
         with self.test_session() as sess:
             mask_ph = self.t.construct_padding_mask(self.raw_input_ph)
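The new assertions turn this test into a weight round-trip check: the values captured from the freshly built model must match, elementwise, the values seen after `Transformer.load_model()` reconstructs it — exactly what a silent re-initialization during loading would break. A compact sketch of the same check for any pair of dicts produced by `get_variable_values()` (the helper name here is illustrative, not part of the repo):

    import numpy as np

    def assert_same_weights(before, after):
        # Both arguments map variable name -> numpy array.
        assert set(before) == set(after), "variable sets differ"
        for name, value in before.items():
            assert np.allclose(value, after[name]), name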
utils.py — 33 changes: 19 additions & 14 deletions

@@ -1,21 +1,24 @@
 import os
 
 import tensorflow as tf
+from baselines.common.tf_util import display_var_info
 from baselines.common.console_util import colorize
 
 REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__)))
 
 
+def print_trainable_variables():
+    display_var_info(tf.trainable_variables())
+
+
 class BaseModelMixin:
     """Abstract object representing an Reader model.
     Code borrowed from: https://github.com/devsisters/DQN-tensorflow/blob/master/dqn/base.py
     with some modifications.
     """
 
-    def __init__(self, model_name, saver_max_to_keep=5, tf_sess_config=None):
-        print("Model name:", model_name)
+    def __init__(self, model_name, tf_sess_config=None):
         self._saver = None
-        self._saver_max_to_keep = saver_max_to_keep
         self._writer = None
         self._model_name = model_name
         self._sess = None

@@ -28,13 +31,10 @@ def __init__(self, model_name, saver_max_to_keep=5, tf_sess_config=None):
         }
         self.tf_sess_config = tf_sess_config
 
-    def scope_vars(self, scope):
-        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
-        assert len(res) > 0
-        print("Variables in scope '%s'" % scope)
-        for v in res:
-            print("\t" + str(v))
-        return res
+    def get_variable_values(self):
+        t_vars = tf.trainable_variables()
+        vals = self.sess.run(t_vars)
+        return {v.name: value for v, value in zip(t_vars, vals)}
 
     def save_checkpoint(self, step=None):
         print(colorize(" [*] Saving checkpoints...", "green"))

@@ -44,10 +44,11 @@ def save_checkpoint(self, step=None):
     def load_checkpoint(self):
         print(colorize(" [*] Loading checkpoints...", "green"))
         ckpt_path = tf.train.latest_checkpoint(self.checkpoint_dir)
-        print(self.checkpoint_dir, ckpt_path)
+        print(self.checkpoint_dir)
+        print("ckpt_path:", ckpt_path)
 
         if ckpt_path:
-            self._saver = tf.train.import_meta_graph(ckpt_path + '.meta')
+            # self._saver = tf.train.import_meta_graph(ckpt_path + '.meta')
             self.saver.restore(self.sess, ckpt_path)
             print(colorize(" [*] Load SUCCESS: %s" % ckpt_path, "green"))
             return True

@@ -68,6 +69,10 @@ def log_dir(self):
     def checkpoint_dir(self):
         return self._get_dir('checkpoints')
 
+    @property
+    def model_dir(self):
+        return self._get_dir('models')
+
     @property
     def tb_dir(self):
         # tensorboard

@@ -81,7 +86,7 @@ def model_name(self):
     @property
     def saver(self):
         if self._saver is None:
-            self._saver = tf.train.Saver(max_to_keep=self._saver_max_to_keep)
+            self._saver = tf.train.Saver(max_to_keep=5)
         return self._saver
 
     @property
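A design note on the `saver` property after these changes: with `import_meta_graph()` commented out, the lazily created `tf.train.Saver` is the one performing the restore, and in TF 1.x a Saver only covers variables that exist at the moment it is constructed — creating it before the graph is built raises "No variables to save". A minimal sketch of that ordering constraint, with illustrative names:

    import tensorflow as tf

    # saver = tf.train.Saver()  # too early: raises ValueError, no variables yet

    x = tf.get_variable("x", shape=[2], initializer=tf.zeros_initializer())
    saver = tf.train.Saver(max_to_keep=5)  # now covers `x` (and can restore it)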
