-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
129 lines (104 loc) · 4.83 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 5
utils.py: Utility Functions
Pengcheng Yin <[email protected]>
Sahil Chopra <[email protected]>
"""
import math
import numpy as np
def pad_sents_char(sents, char_pad_token, max_word_length=21):
    """ Pad list of sentences according to the longest sentence in the batch and max_word_length.
    @param sents (list[list[list[int]]]): list of sentences, result of `words2charindices()`
        from `vocab.py`
    @param char_pad_token (int): index of the character-padding token
    @param max_word_length (int): number of characters every word is padded or truncated to;
        words longer than this are truncated (default 21, as the assignment specifies)
    @returns sents_padded (list[list[list[int]]]): list of sentences where sentences/words shorter
        than the max length sentence/word are padded out with the appropriate pad token, such that
        each sentence in the batch now has same number of words and each word has an equal
        number of characters. All inner lists are newly created, so the caller's input is
        never aliased or mutated. An empty batch yields an empty list.
        Output shape: (batch_size, max_sentence_length, max_word_length)
    """
    ### YOUR CODE HERE for part 1f
    # Pads at character level (every word padded/truncated to max_word_length) and at word
    # level (every sentence padded to the longest sentence in the batch).
    if not sents:
        # max() below would raise on an empty batch.
        return []

    max_sent_length = max(len(sent) for sent in sents)
    sents_padded = []
    for sent in sents:
        # Slicing truncates long words and always copies; multiplying a list by a
        # non-positive count gives [], so one expression handles short, long and exact words.
        padded_words = [
            word[:max_word_length] + [char_pad_token] * (max_word_length - len(word))
            for word in sent
        ]
        # Each padding word is a fresh list so no two output slots share one object.
        padded_words.extend(
            [char_pad_token] * max_word_length
            for _ in range(max_sent_length - len(sent))
        )
        sents_padded.append(padded_words)
    ### END YOUR CODE
    return sents_padded
def pad_sents(sents, pad_token):
    """ Pad list of sentences according to the longest sentence in the batch.
    @param sents (list[list[int]]): list of sentences, where each sentence
        is represented as a list of words
    @param pad_token (int): padding token
    @returns sents_padded (list[list[int]]): list of sentences where sentences shorter
        than the max length sentence are padded out with the pad_token, such that
        each sentence in the batch now has equal length. The returned lists are new;
        the caller's sentences are left unmodified. An empty batch yields an empty list.
        Output shape: (batch_size, max_sentence_length)
    """
    ### COPY OVER YOUR CODE FROM ASSIGNMENT 4
    if not sents:
        # max() below would raise on an empty batch.
        return []
    max_length = max(len(sent) for sent in sents)
    # Build new lists rather than calling sent.extend(), which would pad the caller's data in place.
    sents_padded = [list(sent) + [pad_token] * (max_length - len(sent)) for sent in sents]
    ### END YOUR CODE FROM ASSIGNMENT 4
    return sents_padded
def read_corpus(file_path, source, encoding=None):
    """ Read file, where each sentence is delineated by a `\n`.
    @param file_path (str): path to file containing corpus
    @param source (str): "tgt" or "src" indicating whether text
        is of the source language or target language; only "tgt" sentences
        get '<s>' and '</s>' markers added
    @param encoding (str | None): text encoding passed to `open()`; None (the default)
        uses Python's platform default, matching the previous behavior
    @returns data (list[list[str]]): one token list per line of the file
    """
    data = []
    # `with` guarantees the file handle is closed even if reading fails part-way.
    with open(file_path, encoding=encoding) as corpus_file:
        for line in corpus_file:
            # Splits on single spaces: consecutive spaces produce '' tokens,
            # and a blank line produces [''].
            sent = line.strip().split(' ')
            # only append <s> and </s> to the target sentence
            if source == 'tgt':
                sent = ['<s>'] + sent + ['</s>']
            data.append(sent)
    return data
def batch_iter(data, batch_size, shuffle=False):
    """ Generate (source, target) batches, each ordered from longest to shortest source sentence.
    @param data (list of (src_sent, tgt_sent)): parallel corpus given as (source, target) pairs
    @param batch_size (int): maximum number of pairs per batch; the final batch may be smaller
    @param shuffle (boolean): if True, permute the example order with np.random.shuffle first
    @yields (list, list): the batch's source sentences and its target sentences, aligned by
        position and ordered by decreasing source length
    """
    n_batches = math.ceil(len(data) / batch_size)

    order = list(range(len(data)))
    if shuffle:
        np.random.shuffle(order)

    for batch_idx in range(n_batches):
        start = batch_idx * batch_size
        batch = [data[k] for k in order[start:start + batch_size]]
        # Longest source first; list.sort is stable, so ties keep their current relative order.
        batch.sort(key=lambda pair: len(pair[0]), reverse=True)
        sources = [pair[0] for pair in batch]
        targets = [pair[1] for pair in batch]
        yield sources, targets