From 21cb7a3896c97b681b3197271b1a680d6cf9b29c Mon Sep 17 00:00:00 2001
From: Lilian Weng
Date: Tue, 6 Nov 2018 23:15:20 -0800
Subject: [PATCH] readme

---
 README.md | 40 ++++++++++++++++++++++++++++++++++++----
 eval.py   |  5 +++--
 2 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 3c5ef89..9e64eac 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# WIP
+# Transformer
 
 Implementation of the *Transformer* model in the paper:
 > Ashish Vaswani, et al. ["Attention is all you need."](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) NIPS 2017.
@@ -14,11 +14,43 @@ Implementations that helped me:
 
 * http://nlp.seas.harvard.edu/2018/04/01/attention.html
 
-A couple of tricking points in the implementation.
+Train a model:
+
+```bash
+# Check the help message:
+
+$ python train.py --help
+
+Usage: train.py [OPTIONS]
+
+Options:
+  --seq-len INTEGER     Input sequence length. [default: 20]
+  --d-model INTEGER     d_model [default: 512]
+  --d-ff INTEGER        d_ff [default: 2048]
+  --n-head INTEGER      n_head [default: 8]
+  --batch-size INTEGER  Batch size [default: 128]
+  --max-steps INTEGER   Max train steps. [default: 300000]
+  --dataset [iwslt15|wmt14|wmt15]
+                        Which translation dataset to use. [default:
+                        iwslt15]
+  --help                Show this message and exit.
+
+# Train a model on the WMT14 dataset:
+
+$ python train.py --dataset wmt14
+```
+
+Evaluate a trained model:
+```bash
+# Suppose the model is saved in the folder `transformer-wmt14-seq20-d512-head8-1541573730` inside the checkpoints folder.
+python eval.py transformer-wmt14-seq20-d512-head8-1541573730
+```
+With the default config, this implementation gets BLEU ~20 on the WMT14 test set.
+
+
+\[WIP\] A couple of tricky points in the implementation:
 
 * How to construct the mask correctly?
 * How to correctly shift decoder input (as training input) and decoder target (as ground truth in the loss function)?
 * How to make the prediction in an autoregressive way?
 * Keeping the embedding of `<pad>` as a constant zero vector is sorta important.
-
-
\ No newline at end of file
diff --git a/eval.py b/eval.py
index 6655f4c..b3d90da 100644
--- a/eval.py
+++ b/eval.py
@@ -21,11 +21,12 @@ def eval(model_name, file_prefix):
     cfg = transformer.config
     batch_size = cfg['train_params']['batch_size']
     seq_len = cfg['train_params']['seq_len'] + 1
+    print(f'batch_size:{batch_size} seq_len:{seq_len}')
 
     dm = DatasetManager(cfg['dataset'])
     dm.maybe_download_data_files()
-    data_iter = dm.data_generator(
-        batch_size, seq_len, data_type='test', file_prefix=file_prefix, epoch=1)
+    data_iter = dm.data_generator(batch_size, seq_len, data_type='test',
+                                  file_prefix=file_prefix, epoch=1)
 
     refs = []
     hypos = []
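
The \[WIP\] bullet points added to the README above are the most error-prone parts of a Transformer implementation, so a short sketch of the first two may help. Below is a minimal, illustrative NumPy version — not the repository's actual code — of a padding mask, a causal (look-ahead) mask, and the decoder input/target shift; `PAD_ID` and the token ids in the toy batch are assumed values.

```python
# A minimal illustrative sketch (NumPy), NOT the repo's actual code.
import numpy as np

PAD_ID = 0  # assumed id of the <pad> token

def padding_mask(seq):
    """(batch, seq_len) ids -> (batch, 1, seq_len); 1 = attend, 0 = padding."""
    return (seq != PAD_ID).astype(np.float32)[:, np.newaxis, :]

def causal_mask(seq_len):
    """Lower-triangular (seq_len, seq_len) mask: position i attends to <= i."""
    return np.tril(np.ones((seq_len, seq_len), dtype=np.float32))

def shift_target(batch):
    """Split a (batch, seq_len + 1) batch into decoder input and loss target.

    The decoder input drops the last token and the target drops the first,
    so at every position the model is trained to predict the *next* token.
    """
    return batch[:, :-1], batch[:, 1:]

# Toy batch: one sentence "<s> a b </s> <pad>" as (assumed) ids.
batch = np.array([[1, 5, 6, 2, 0]])
dec_input, dec_target = shift_target(batch)
print(dec_input)                # [[1 5 6 2]]
print(dec_target)               # [[5 6 2 0]]
print(padding_mask(dec_input))  # [[[1. 1. 1. 1.]]]
print(causal_mask(4))
```

Note the extra token per example in the shift: this is presumably also why `eval.py` in the patch reads the sequence length as `cfg['train_params']['seq_len'] + 1` — after shifting, both the decoder input and the target are back to `seq_len` tokens.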
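The remaining two bullets — autoregressive prediction and keeping the `<pad>` embedding at a constant zero — can be sketched the same way. Again, this is an illustrative stand-in rather than the repo's implementation; the special-token ids and the `step_fn` interface are assumptions.

```python
# Another illustrative sketch, NOT the repo's code: greedy autoregressive
# decoding plus an embedding table whose <pad> row is frozen at zero.
import numpy as np

PAD_ID, BOS_ID, EOS_ID = 0, 1, 2  # assumed special-token ids

def zero_pad_embedding(vocab_size, d_model, seed=0):
    """Random embedding table with the <pad> row pinned to a zero vector,
    so padded positions contribute nothing to the model input."""
    emb = np.random.default_rng(seed).normal(size=(vocab_size, d_model))
    emb[PAD_ID] = 0.0
    return emb.astype(np.float32)

def greedy_decode(step_fn, max_len):
    """Autoregressive loop: each predicted token is fed back as input.

    `step_fn(prefix)` is assumed to return next-token logits over the
    vocabulary, given the list of token ids decoded so far.
    """
    out = [BOS_ID]
    for _ in range(max_len):
        next_id = int(np.argmax(step_fn(out)))
        out.append(next_id)
        if next_id == EOS_ID:
            break
    return out

# Toy step_fn: predict token 5 twice, then </s>.
fake_logits = lambda prefix: np.eye(10)[EOS_ID if len(prefix) >= 3 else 5]
print(greedy_decode(fake_logits, max_len=8))  # [1, 5, 5, 2]
```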