Skip to content

Commit

Permalink
correct mask
Browse files Browse the repository at this point in the history
  • Loading branch information
lilianweng committed Oct 24, 2018
1 parent 56ef323 commit c6d3415
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 72 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.idea/
__pycache__/
.pytest_cache/
logs/
checkpoints/
tb/
19 changes: 12 additions & 7 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import numpy as np

# IDs of special characters.
# IDs of the special vocabulary tokens. These must match the insertion
# order used when the vocabulary is built (see load_vocab below).
PAD_ID = 0       # '<pad>': placeholder used to right-pad short sentences.
EMPTY_ID = PAD_ID  # Deprecated pre-rename alias of PAD_ID; kept for old callers.
UNKNOWN_ID = 1   # '<unk>': out-of-vocabulary word.
START_ID = 2     # '<s>': start-of-sentence marker.
END_ID = 3       # '</s>': end-of-sentence marker.


def download_data_from_url(download_url, data_dir):
Expand Down Expand Up @@ -45,22 +47,25 @@ def load_vocab(vocab_file):
# <unk>: unknown word.
# <s>: start of a sentence.
# </s>: end of a sentence.
# In addition, we add <pad> as a placeholder for a padding space.
words = list(map(lambda w: w.strip().lower(), open(vocab_file)))
words.insert(0, '<empty>')
words.insert(0, '<pad>')
words = words[:4] + list(set(words[4:])) # Keep the special characters on top.
word2id = {word: i for i, word in enumerate(words)}
id2word = {i: word for i, word in enumerate(words)}

assert id2word[EMPTY_ID] == '<empty>'
assert id2word[PAD_ID] == '<pad>'
assert id2word[UNKNOWN_ID] == '<unk>'
assert id2word[START_ID] == '<s>'
assert id2word[END_ID] == '</s>'

return word2id, id2word


def sentence_pair_iterator(file1, file2, word2id1, word2id2, seq_len):
"""
The sentence is discarded if it is longer than `seq_len`; otherwise we pad it with
'<pad>' so that it has the exact length `seq_len`.
Args:
file1 (str): training data in language 1.
Expand All @@ -76,8 +81,8 @@ def sentence_pair_iterator(file1, file2, word2id1, word2id2, seq_len):
def parse_line(line, word2id):
    """Convert one raw sentence into a fixed-length list of word ids.

    Lower-cases and whitespace-tokenizes `line`, maps each word through
    `word2id` (falling back to UNKNOWN_ID for out-of-vocabulary words),
    then right-pads with PAD_ID up to `seq_len` (captured from the
    enclosing scope). Sentences already >= seq_len are left unpadded.

    Args:
        line (str): one raw sentence.
        word2id (dict): word -> integer id mapping.

    Returns:
        list[int]: word ids, length max(len(tokens), seq_len).
    """
    # Use a fresh name instead of rebinding the `line` parameter.
    tokens = line.strip().lower().split()
    word_ids = [word2id.get(w, UNKNOWN_ID) for w in tokens]
    # If the sentence is not long enough, extend with '<pad>' symbols.
    word_ids += [PAD_ID] * max(0, seq_len - len(word_ids))
    return word_ids

for l1, l2 in zip(open(file1), open(file2)):
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def train(seq_len=100, d_model=512, n_head=8, batch_size=64, max_steps=100000):
model_name=model_name,
tf_sess_config=tf_sess_config
)
transformer.build_model(id2en, id2vi, **train_params)
transformer.build_model(id2en, id2vi, PAD_ID, **train_params)
transformer.print_trainable_variables()

test_data_iter = data_generator(batch_size, seq_len, data_dir=data_dir, file_prefix='tst2013')
Expand Down
Loading

0 comments on commit c6d3415

Please sign in to comment.