-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhmcan.py
319 lines (271 loc) · 12.1 KB
/
hmcan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from typing import Union, Dict, Any, Optional
import torchvision
from transformers import BertModel
from faknow.model.model import AbstractModel
from faknow.model.layers.transformer import (AddNorm, MultiHeadAttention, FFN,
PositionalEncoding)
class HMCAN(AbstractModel):
    r"""
    HMCAN: Hierarchical Multi-modal Contextual Attention Network for fake news Detection, SIGIR 2021
    paper: https://dl.acm.org/doi/10.1145/3404835.3462871
    code: https://github.com/wangjinguang502/HMCAN
    """

    # Number of hierarchical text groups built from the BERT block outputs.
    # The classifier input size depends on it (each group -> 2 * output_dims).
    _NUM_TEXT_GROUPS = 3

    def __init__(self,
                 left_num_layers=2,
                 left_num_heads=12,
                 dropout=0.1,
                 right_num_layers=2,
                 right_num_heads=12,
                 alpha=0.7,
                 pre_trained_bert_name='bert-base-uncased'):
        """
        Args:
            left_num_layers(int): the numbers of the left Attention&FFN layer
                in Contextual Transformer, Default=2.
            left_num_heads(int): the numbers of head in
                Multi-Head Attention layer(in the left Attention&FFN),
                Default=12.
            dropout(float): dropout rate, Default=0.1.
            right_num_layers(int): the numbers of the right Attention&FFN layer
                in Contextual Transformer, Default=2.
            right_num_heads(int): the numbers of head in
                Multi-Head Attention layer(in the right Attention&FFN),
                Default=12.
            alpha(float): the weight of the first Attention&FFN layer's output,
                Default=0.7.
            pre_trained_bert_name(str): the bert name str. default='bert-base-uncased'
        """
        super(HMCAN, self).__init__()
        self.alpha = alpha
        self.output_dims = 768
        self.loss_func = nn.CrossEntropyLoss()

        # text: frozen BERT that also returns the hidden states of every block
        self.bert = BertModel.from_pretrained(
            pre_trained_bert_name,
            output_hidden_states=True).requires_grad_(False)

        # image: frozen ResNet-50 with its last two children (avgpool, fc)
        # removed, so it outputs a 2048-channel feature map
        resnet50 = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        for param in resnet50.parameters():
            param.requires_grad = False
        self.resnet50 = nn.Sequential(*list(resnet50.children())[:-2])
        # project the 2048-channel feature map to the text feature size
        self.image_conv = nn.Conv2d(2048, self.output_dims, 4)
        self.image_bn = nn.BatchNorm2d(self.output_dims)

        # Contextual Transformer: text->image and image->text directions
        self.contextual_transform1 = _TextImageTransformer(
            left_num_layers, left_num_heads, right_num_layers, right_num_heads,
            dropout, self.output_dims)
        self.contextual_transform2 = _TextImageTransformer(
            left_num_layers, left_num_heads, right_num_layers, right_num_heads,
            dropout, self.output_dims)

        # Classifier over the concatenated fused features of all text groups
        self.classifier = nn.Sequential(
            nn.Linear(self.output_dims * 2 * self._NUM_TEXT_GROUPS, 256),
            nn.ReLU(True), nn.BatchNorm1d(256), nn.Linear(256, 2))

    @classmethod
    def _group_text_layers(cls, semantics):
        """
        Sum consecutive BERT block outputs into ``_NUM_TEXT_GROUPS``
        non-overlapping groups of equal size.

        Args:
            semantics (Sequence[Tensor]): per-block hidden states, with the
                embedding output already excluded.

        Returns:
            List[Tensor]: one summed tensor per group. For bert-base
            (12 blocks) the groups are blocks 1-4, 5-8 and 9-12. When the
            block count is not divisible by the group count, the trailing
            remainder blocks are ignored.

        Raises:
            ValueError: if there are fewer blocks than groups.
        """
        group_size = len(semantics) // cls._NUM_TEXT_GROUPS
        if group_size == 0:
            raise ValueError(
                f"expected at least {cls._NUM_TEXT_GROUPS} hidden states, "
                f"got {len(semantics)}")
        return [
            sum(semantics[start:start + group_size])
            for start in range(0, group_size * cls._NUM_TEXT_GROUPS,
                               group_size)
        ]

    def forward(self, token_id: torch.Tensor, mask: torch.Tensor,
                image: torch.Tensor):
        """
        Args:
            token_id (Tensor): text token ids
            mask (Tensor): text masks
            image (Tensor): image pixels

        Returns:
            Tensor: prediction of being fake news, shape=(batch_size, 2)
        """
        # hidden_states[0] is the embedding output; [1:] are the outputs of
        # the transformer blocks (12 for bert-base), all of which are used
        semantics = self.bert(token_id, attention_mask=mask).hidden_states[1:]
        text_embeding = self._group_text_layers(semantics)

        image_features = self.resnet50(image)
        # [batch_size, 768, 4, 4] (assumes 224x224 input images)
        image_features = F.relu(self.image_bn(self.image_conv(image_features)))
        # flatten the spatial grid into a token sequence: [batch_size, 16, 768]
        image_features = image_features.flatten(2).transpose(1, 2)

        # The padding mask is deliberately replaced by all ones ("ban mask"),
        # so every text position, padding included, is treated as valid.
        mask = torch.ones_like(mask)

        output = []
        for text_features in text_embeding:
            text_image = self.contextual_transform1(text_features,
                                                    image_features, mask, None)
            image_text = self.contextual_transform2(image_features,
                                                    text_features, None, mask)
            output.append(self.alpha * text_image +
                          (1 - self.alpha) * image_text)

        classifier_input = torch.cat(output, dim=1)
        return self.classifier(classifier_input)

    def calculate_loss(self, data: Dict[str, Any]) -> Tensor:
        """
        calculate total loss

        Args:
            data(Dict[str, any]): batch data dict

        Returns:
            Tensor: total_loss
        """
        token_id = data['text']['token_id']
        mask = data['text']['mask']
        image = data['image']
        label = data['label']
        output = self.forward(token_id, mask, image)
        loss = self.loss_func(output, label)
        return loss

    def predict(self, data_without_label: Dict[str, Any]) -> Tensor:
        """
        predict the probability of being fake news

        Args:
            data_without_label (Dict[str, Any]): batch data dict

        Returns:
            Tensor: softmax probability, shape=(batch_size, 2)
        """
        token_id = data_without_label['text']['token_id']
        mask = data_without_label['text']['mask']
        image = data_without_label['image']
        pred = self.forward(token_id, mask, image)
        pred = torch.softmax(pred, dim=-1)
        return pred
class _TextImageTransformer(nn.Module):
    """
    Contextual attention block that fuses one modality (left) with another
    (right).

    The left features are layer-normalised, given positional encodings and
    passed through a self-attention encoder. The right features then attend
    to the encoded left features. Both results are mean-pooled over their
    sequence dimension and concatenated.
    """

    def __init__(self, left_num_layers: int, left_num_heads: int,
                 right_num_layers: int, right_num_heads: int, dropout: float,
                 feature_dim: int):
        """
        Args:
            left_num_layers(int): number of layers in the left (self-attention)
                encoder.
            left_num_heads(int): number of attention heads in the left encoder.
            right_num_layers(int): number of layers in the right
                (cross-attention) encoder.
            right_num_heads(int): number of attention heads in the right
                encoder.
            dropout(float): dropout rate.
            feature_dim(int): feature dimension of both inputs.
        """
        super().__init__()
        # registration order is kept: input_norm, embedding, transformer1,
        # transformer2
        self.input_norm = nn.LayerNorm(feature_dim)
        self.embedding = PositionalEncoding(feature_dim, dropout, max_len=1000)
        self.transformer1 = _TransformerEncoder(
            left_num_layers, feature_dim, left_num_heads, feature_dim, dropout)
        self.transformer2 = _TransformerEncoder(
            right_num_layers, feature_dim, right_num_heads, feature_dim,
            dropout)

    def forward(self,
                left_features: Tensor,
                right_features: Tensor,
                left_mask: Optional[Tensor] = None,
                right_mask: Optional[Tensor] = None):
        """
        Args:
            left_features(Tensor): input of the left (self-attention) encoder,
                shape=(batch_size, length, embedding_dim).
            right_features(Tensor): queries of the right (cross-attention)
                encoder, shape=(batch_size, length, embedding_dim).
            left_mask(Union[Tensor, None]): mask forwarded to the left encoder,
                shape=(batch_size, ...).
            right_mask(Union[Tensor, None]): mask forwarded to the right
                encoder, shape=(batch_size, ...).
                NOTE(review): the right encoder's keys/values are the encoded
                left features, yet it receives ``right_mask``; confirm intended.

        Returns:
            Tensor: shape=(batch_size, 2 * embedding_dim)
        """
        encoded_left = self.embedding(self.input_norm(left_features))
        encoded_left = self.transformer1(encoded_left, encoded_left,
                                         encoded_left, left_mask)
        encoded_right = self.transformer2(right_features, encoded_left,
                                          encoded_left, right_mask)
        pooled = [encoded_left.mean(dim=1), encoded_right.mean(dim=1)]
        return torch.cat(pooled, dim=-1)
class _TransformerEncoder(nn.Module):
    """
    Stack of ``_TransformerEncoderLayer`` blocks used by the Contextual
    Transformer (``_TextImageTransformer``).
    """

    def __init__(self, num_layers: int, input_dim: int, num_heads: int,
                 feature_dims: int, dropout: float):
        """
        Args:
            num_layers(int): number of stacked attention blocks, must be > 0.
            input_dim(int): input dimension.
            num_heads(int): head num of each attention block.
            feature_dims(int): hidden size of each block's FFN.
            dropout(float): dropout rate.
        """
        super().__init__()
        self.input_dim = input_dim
        assert num_layers > 0
        self.encoder_layers = nn.ModuleList([
            _TransformerEncoderLayer(input_dim, feature_dims, num_heads,
                                     dropout) for _ in range(num_layers)
        ])

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                mask: Union[Tensor, None]):
        """
        Args:
            query(Tensor): shape=(batch_size, q_num, d)
            key(Tensor): shape=(batch_size, k-v_num, d)
            value(Tensor): shape=(batch_size, k-v_num, v-dim)
            mask(Union[Tensor, None]): shape=(batch_size, ...). When given it
                is summed over its last dimension (e.g. a 0/1 padding mask
                becomes per-sample valid lengths) before reaching each layer.

        Returns:
            Tensor: output of the last encoder layer.
        """
        if mask is not None:
            mask = mask.sum(-1, keepdim=False)

        # The same tensor passed as query, key and value means self-attention;
        # then every layer attends over the previous layer's output.
        # Otherwise (cross-attention) only the query is refined while the
        # keys/values stay fixed.
        self_attention = query is key and key is value
        for encoder_layer in self.encoder_layers:
            # chain the layers so each one refines the previous output
            query = encoder_layer(query, key, value, mask)
            if self_attention:
                key = value = query
        return query
class _TransformerEncoderLayer(nn.Module):
    """
    Single attention + feed-forward block of the Contextual Transformer.

    Computes ``y = addnorm1(query, attention(query, key, value))`` and
    returns ``addnorm2(y, ffn(y))``.
    """

    def __init__(self,
                 input_dim: int,
                 ffn_hidden_size: int,
                 head_num: int,
                 dropout=0.,
                 bias=False):
        """
        Args:
            input_dim (int): model (input) dimension; must be divisible by
                ``head_num``.
            ffn_hidden_size (int): hidden layer dimension of the FFN.
            head_num (int): number of attention heads.
            dropout (float): dropout rate, default=0.
            bias (bool): whether to use bias in the attention's Linear layers,
                default=False.
        """
        super().__init__()
        assert input_dim % head_num == 0, \
            f"model dim {input_dim} not divisible by {head_num} heads"
        per_head_dim = input_dim // head_num
        self.attention = MultiHeadAttention(input_dim,
                                            input_dim,
                                            input_dim,
                                            head_num,
                                            out_size=per_head_dim,
                                            dropout=dropout,
                                            bias=bias)
        self.addnorm1 = AddNorm(input_dim, dropout)
        self.ffn = FFN(input_dim,
                       ffn_hidden_size,
                       input_dim,
                       dropout,
                       activation=nn.GELU())
        self.addnorm2 = AddNorm(input_dim, dropout)

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                valid_lens: Optional[Tensor] = None):
        """
        Args:
            query(Tensor): shape=(batch_size, num_steps, input_size)
            key(Tensor): shape=(batch_size, k-v_num, d)
            value(Tensor): shape=(batch_size, k-v_num, v-dim)
            valid_lens (Tensor): shape=(batch_size,), default=None; passed
                unchanged to the attention module.

        Returns:
            Tensor: result of the second AddNorm; presumably the same shape
            as ``query`` because of the residual connections (verify
            against AddNorm).
        """
        attended = self.attention(query, key, value, valid_lens)
        hidden = self.addnorm1(query, attended)
        transformed = self.ffn(hidden)
        return self.addnorm2(hidden, transformed)