Commit

readme
lilianweng committed Nov 7, 2018
1 parent 6957ed5 commit 21cb7a3
Showing 2 changed files with 39 additions and 6 deletions.
40 changes: 36 additions & 4 deletions README.md
@@ -1,4 +1,4 @@
# WIP
# Transformer
Implementation of the *Transformer* model in the paper:

> Ashish Vaswani, et al. ["Attention is all you need."](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) NIPS 2017.
@@ -14,11 +14,43 @@ Implementations that helped me:
* http://nlp.seas.harvard.edu/2018/04/01/attention.html


A couple of tricky points in the implementation.
Train a model:

```bash
# Check the help message:

$ python train.py --help

Usage: train.py [OPTIONS]

Options:
--seq-len INTEGER Input sequence length. [default: 20]
--d-model INTEGER d_model [default: 512]
--d-ff INTEGER d_ff [default: 2048]
--n-head INTEGER n_head [default: 8]
--batch-size INTEGER Batch size [default: 128]
--max-steps INTEGER Max train steps. [default: 300000]
--dataset [iwslt15|wmt14|wmt15]
Which translation dataset to use. [default:
iwslt15]
--help Show this message and exit.

# Train a model on the WMT14 dataset:

$ python train.py --dataset wmt14
```

Evaluate a trained model:
```bash
# Suppose the trained model is saved in the folder `transformer-wmt14-seq20-d512-head8-1541573730` under the checkpoints directory.
python eval.py transformer-wmt14-seq20-d512-head8-1541573730
```
With the default config, this implementation reaches BLEU ~ 20 on the WMT14 test set.
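
For reference, a corpus-level BLEU over the collected references and hypotheses can be computed along these lines (a sketch using `nltk`; the BLEU implementation actually used by `eval.py` may differ):

```python
from nltk.translate.bleu_score import corpus_bleu

# One entry per test sentence: a list of reference token lists, and one hypothesis.
refs = [[['the', 'cat', 'sat', 'on', 'the', 'mat']]]
hypos = [['the', 'cat', 'sat', 'on', 'the', 'mat', 'today']]

print(f'BLEU: {corpus_bleu(refs, hypos) * 100:.2f}')
```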


\[WIP\] A couple of tricky points in the implementation (a rough sketch of how these pieces can fit together follows the list below).

* How to construct the mask correctly?
* How to correctly shift decoder input (as training input) and decoder target (as ground truth in the loss function)?
* How to make the prediction in an autoregressive way?
* Keeping the embedding of `<pad>` as a constant zero vector is sorta important.
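
Below is a minimal NumPy sketch of how these pieces can fit together; `PAD_ID`, the helper names, and the shapes are illustrative assumptions rather than this repo's actual code:

```python
import numpy as np

PAD_ID = 0  # assumed id of the `<pad>` token

def padding_mask(token_ids):
    # 1.0 where the token is real, 0.0 where it is `<pad>`.
    # Shape [batch, 1, seq_len] so it broadcasts over the query axis of
    # attention weights shaped [batch, seq_len, seq_len].
    return (token_ids != PAD_ID).astype(np.float32)[:, np.newaxis, :]

def look_ahead_mask(seq_len):
    # Lower-triangular matrix: position i may only attend to positions <= i,
    # which keeps decoder self-attention causal.
    return np.tril(np.ones((seq_len, seq_len), dtype=np.float32))

def shift_decoder_sequences(target_ids):
    # Decoder input drops the last token and decoder target drops the first,
    # so at step t the model reads y_1..y_t and is trained to predict y_{t+1}.
    decoder_input = target_ids[:, :-1]
    decoder_target = target_ids[:, 1:]
    return decoder_input, decoder_target

def zero_pad_embedding(embedding_matrix):
    # Pin the `<pad>` row of the embedding table to a constant zero vector so
    # padded positions contribute nothing downstream.
    embedding_matrix[PAD_ID, :] = 0.0
    return embedding_matrix

def greedy_decode(predict_next_id, start_id, end_id, max_len):
    # Autoregressive prediction: feed the model its own output one step at a
    # time. `predict_next_id` is an assumed callable mapping a prefix of
    # token ids to the next token id.
    out = [start_id]
    for _ in range(max_len):
        next_id = predict_next_id(out)
        out.append(next_id)
        if next_id == end_id:
            break
    return out
```

In decoder self-attention the padding mask and the look-ahead mask are typically combined (e.g. by element-wise multiplication) before the masked positions are pushed toward -inf in the attention logits.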


5 changes: 3 additions & 2 deletions eval.py
@@ -21,11 +21,12 @@ def eval(model_name, file_prefix):
     cfg = transformer.config
     batch_size = cfg['train_params']['batch_size']
     seq_len = cfg['train_params']['seq_len'] + 1
+    print(f'batch_size:{batch_size} seq_len:{seq_len}')
 
     dm = DatasetManager(cfg['dataset'])
     dm.maybe_download_data_files()
-    data_iter = dm.data_generator(
-        batch_size, seq_len, data_type='test', file_prefix=file_prefix, epoch=1)
+    data_iter = dm.data_generator(batch_size, seq_len, data_type='test',
+                                  file_prefix=file_prefix, epoch=1)
 
     refs = []
     hypos = []
