Skip to content

Commit

Permalink
tensorboard stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
lilianweng committed Oct 23, 2018
1 parent 4e4647e commit 56ef323
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
__pycache__/
logs/
checkpoints/
tb/
19 changes: 6 additions & 13 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,24 @@ def train(seq_len=100, d_model=512, n_head=8, batch_size=64, max_steps=100000):
transformer.build_model(id2en, id2vi, **train_params)
transformer.print_trainable_variables()

logger.configure(dir=transformer.log_dir, format_strs=['stdout', 'csv'])

step = 0
test_data_iter = data_generator(batch_size, seq_len, data_dir=data_dir, file_prefix='tst2013')
logger.configure(dir=transformer.log_dir, format_strs=['stdout', 'csv'])

transformer.init()
while step < max_steps:
transformer.init() # step = 0
while transformer.step < max_steps:
for input_ids, target_ids in data_generator(batch_size, seq_len, data_dir=data_dir):
step += 1
logger.logkv('step', step)

meta = transformer.train(input_ids, target_ids)
for k, v in meta.items():
logger.logkv('train_' + k, v)
logger.logkv(k, v)

if step % 100 == 0:
if transformer.step % 100 == 0:
test_inp_ids, test_target_ids = next(test_data_iter)
meta = transformer.evaluate(test_inp_ids, test_target_ids)
for k, v in meta.items():
logger.logkv('test_' + k, v)
logger.dumpkvs()

if step % 1000 == 0:
# Save the model checkpoint.
transformer.save_model(step=step)
transformer.done()


if __name__ == '__main__':
Expand Down
49 changes: 37 additions & 12 deletions transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Implementations that helped me:
https://github.com/Kyubyong/transformer/
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
http://nlp.seas.harvard.edu/2018/04/01/attention.html
"""
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(self, num_heads=8, d_model=512, d_ff=2048,
self.use_label_smoothing = use_label_smoothing

self._is_init = False
self.step = 0 # training step.

# The following variables will be initialized in build_model().
self._input_id2word = None
Expand Down Expand Up @@ -271,41 +273,64 @@ def build_model(self, input_id2word, target_id2word, **train_params):

logits = tf.layers.dense(dec_out, target_vocab) # [batch, target_vocab]
probas = tf.nn.softmax(logits)

self.probas = probas
self._output = tf.argmax(probas, axis=-1, output_type=tf.int32)
print(logits.shape, probas.shape, self._output.shape)
self.probas = probas

self._loss = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=target))

optim = tf.train.AdamOptimizer(learning_rate=lr)
self._train_op = optim.minimize(self._loss)

with tf.variable_scope(self.model_name + '_summary'):
tf.summary.scalar('loss', self._loss)
self.merged_summary = tf.summary.merge_all()

def init(self):
    """Initialize all TF variables in the session and reset the training step.

    Must be called once after build_model() and before train()/evaluate();
    train() asserts on the _is_init flag set here.
    """
    # Both global and local variable initializers are run in one sess.run call.
    self.sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    self._is_init = True
    # Reset the step counter so TensorBoard summaries/checkpoints restart at 0.
    self.step = 0

def done(self):
    """Finish training: flush/close the TensorBoard writer and save a final checkpoint."""
    self.writer.close()
    # NOTE(review): tf.train.Saver.save() requires (sess, save_path) positional
    # arguments — this zero-arg call only works if self.saver is a project
    # wrapper with bound defaults. Confirm against the saver's definition.
    self.saver.save()  # Final checkpoint.

def train(self, input_ids, target_ids):
    """Run one optimization step on a batch of (input, target) id sequences.

    Args:
        input_ids: int array, shape matching self.input_ph (batch, seq_len).
        target_ids: int array, shape matching self.target_ph (batch, seq_len).

    Returns:
        dict with 'train_loss' (float batch loss) and 'step' (int global step).

    Side effects: increments self.step, writes a TensorBoard summary each
    step, and saves a model checkpoint every 1000 steps.
    """
    assert self._is_init, "Call .init() first."
    self.step += 1
    # Fetch the merged summary in the same run as the train op so the
    # logged loss corresponds to exactly this batch/step.
    train_loss, summary, _ = self.sess.run(
        [self.loss, self.merged_summary, self.train_op], feed_dict={
            self.input_ph: input_ids.astype(np.int32),
            self.target_ph: target_ids.astype(np.int32),
        })
    self.writer.add_summary(summary, global_step=self.step)

    if self.step % 1000 == 0:
        # Save the model checkpoint every 1000 steps.
        self.save_model(step=self.step)

    return {'train_loss': train_loss, 'step': self.step}

def predict(self, input_ids):
    """Greedy-decode output ids for a batch of input id sequences.

    Args:
        input_ids: int array of shape (batch_size, seq_len), matching the
            static shape of self.input_ph.

    Returns:
        int32 array of shape (batch_size, seq_len) of predicted token ids.
    """
    batch_size, seq_len = self.input_ph.shape.as_list()
    assert input_ids.shape == (batch_size, seq_len)

    input_ids = input_ids.astype(np.int32)
    # Decoder input starts as all zeros and is filled one position per step.
    pred_ids = np.zeros(input_ids.shape, dtype=np.int32)

    # Predict one output a time autoregressively.
    for i in range(seq_len):
        next_pred = self.sess.run(self._output, feed_dict={
            self.input_ph: input_ids,
            self.target_ph: pred_ids,
        })
        # Only update the i-th COLUMN each step. The original
        # `pred_ids[: i] = next_pred[: i]` sliced the first i ROWS (and
        # wrote nothing at i=0), so decoded tokens never propagated.
        pred_ids[:, i] = next_pred[:, i]

    return pred_ids

Expand Down

0 comments on commit 56ef323

Please sign in to comment.