"""Basic or helper implementation."""
import torch
from torch import nn
from torch.nn import functional


def convert_to_one_hot(indices, num_classes):
    """Convert a vector of class indices into a one-hot matrix.

    Args:
        indices (Tensor): A vector containing indices,
            whose size is (batch_size,).
        num_classes (int): The number of classes, which becomes
            the second dimension of the resulting one-hot matrix.

    Returns:
        one_hot: The one-hot matrix of size (batch_size, num_classes).
    """
    batch_size = indices.size(0)
    indices = indices.unsqueeze(1)
    one_hot = indices.new_zeros(batch_size, num_classes).scatter_(1, indices, 1)
    return one_hot
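

# Usage sketch for convert_to_one_hot (illustrative only; the index values
# below are made up for this example):
#   >>> convert_to_one_hot(torch.tensor([2, 0]), num_classes=3)
#   tensor([[0, 0, 1],
#           [1, 0, 0]])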


def masked_softmax(logits, mask=None):
    """Compute a softmax over ``logits`` that ignores masked positions.

    Positions where ``mask`` is 0 receive (near-)zero probability, and the
    remaining probabilities are renormalized so that each row sums to 1.
    """
    eps = 1e-20
    probs = functional.softmax(logits, dim=1)
    if mask is not None:
        mask = mask.float()
        probs = probs * mask + eps
        probs = probs / probs.sum(1, keepdim=True)
    return probs
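

# Usage sketch for masked_softmax (illustrative; the mask below is a
# hypothetical example):
#   logits = torch.randn(2, 4)
#   mask = torch.tensor([[1, 1, 0, 0],
#                        [1, 1, 1, 1]])
#   probs = masked_softmax(logits, mask)  # each row sums to ~1; masked entries are ~0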


def greedy_select(logits, mask=None):
    """Select the highest-probability class as a one-hot vector."""
    probs = masked_softmax(logits=logits, mask=mask)
    one_hot = convert_to_one_hot(indices=probs.max(1)[1],
                                 num_classes=logits.size(1))
    return one_hot


def st_gumbel_softmax(logits, temperature=1.0, mask=None):
    """Return the result of Straight-Through Gumbel-Softmax estimation.

    It approximates discrete sampling via the Gumbel-Softmax trick and
    applies the biased straight-through (ST) estimator. In the forward
    pass it emits a discrete one-hot result, while in the backward pass
    it approximates the categorical distribution with the smooth
    Gumbel-Softmax distribution.

    Args:
        logits (Tensor): Unnormalized log-probabilities of size
            (batch_size, num_classes).
        temperature (float): A temperature parameter. The higher the
            value, the smoother the resulting distribution.
        mask (Tensor, optional): If given, it masks the softmax so that
            indices whose mask value is 0 are never selected. Its size
            is (batch_size, num_classes).

    Returns:
        y: The sampled output, which has the property explained above.
    """
    eps = 1e-20
    # Sample Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
    u = logits.data.new(*logits.size()).uniform_()
    gumbel_noise = -torch.log(-torch.log(u + eps) + eps)
    y = logits + gumbel_noise
    y = masked_softmax(logits=y / temperature, mask=mask)
    y_argmax = y.max(1)[1]
    y_hard = convert_to_one_hot(indices=y_argmax, num_classes=y.size(1)).float()
    y = (y_hard - y).detach() + y
    return y
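

# Note on the straight-through trick above: ``(y_hard - y).detach() + y``
# equals ``y_hard`` in the forward pass, but its gradient is that of the
# soft ``y``, which is what makes the estimator biased yet differentiable.
# Usage sketch (illustrative; shapes and temperature are hypothetical):
#   logits = torch.randn(8, 5, requires_grad=True)
#   sample = st_gumbel_softmax(logits, temperature=0.5)  # one-hot rows, differentiable w.r.t. logits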


def sequence_mask(sequence_length, max_length=None):
    """Build a boolean mask of size (batch_size, max_length).

    Entry (i, j) is True iff j < sequence_length[i].
    """
    if max_length is None:
        max_length = int(sequence_length.data.max())
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_length).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_length)
    seq_range_expand = seq_range_expand.to(sequence_length)
    seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
    return seq_range_expand < seq_length_expand
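

# Usage sketch for sequence_mask (illustrative lengths):
#   lengths = torch.tensor([3, 1])
#   sequence_mask(lengths, max_length=4)
#   # tensor([[ True,  True,  True, False],
#   #         [ True, False, False, False]])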


def reverse_padded_sequence(inputs, lengths, batch_first=False):
    """Reverse each sequence in a padded batch according to its length.

    Inputs should have size ``T x B x *`` if ``batch_first`` is False, or
    ``B x T x *`` if True, where T is the length of the longest sequence
    (or larger), B is the batch size, and * is the feature dimension.

    Args:
        inputs (Tensor): Padded batch of variable-length sequences.
        lengths (list[int]): List of sequence lengths.
        batch_first (bool, optional): If True, inputs is of size B x T x *.

    Returns:
        A tensor of the same size as ``inputs``, but with each sequence
        reversed according to its length; padding positions stay in place.
    """
    if not batch_first:
        inputs = inputs.transpose(0, 1)
    if inputs.size(0) != len(lengths):
        raise ValueError('inputs is incompatible with lengths.')
    reversed_indices = [list(range(inputs.size(1)))
                        for _ in range(inputs.size(0))]
    for i, length in enumerate(lengths):
        if length > 0:
            reversed_indices[i][:length] = reversed_indices[i][length - 1::-1]
    reversed_indices = (torch.LongTensor(reversed_indices).unsqueeze(2)
                        .expand_as(inputs))
    # Keep the indices as int64 (gather requires it); only move them to the
    # same device as the inputs.
    reversed_indices = reversed_indices.to(inputs.device)
    reversed_inputs = torch.gather(inputs, 1, reversed_indices)
    if not batch_first:
        reversed_inputs = reversed_inputs.transpose(0, 1)
    return reversed_inputs
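

# Usage sketch for reverse_padded_sequence (illustrative; a batch of two
# sequences with lengths 3 and 2):
#   inputs = torch.arange(8.0).view(2, 4, 1)   # B x T x * with batch_first=True
#   reverse_padded_sequence(inputs, [3, 2], batch_first=True)
#   # Row 0 becomes [2, 1, 0, 3]; row 1 becomes [5, 4, 6, 7] (padding untouched).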