"""
paper: Efficient Low-rank Multimodal Fusion with Modality-Specific Factors
From: https://github.com/Justin1904/Low-rank-Multimodal-Fusion
"""
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_normal_

from models.subNets.FeatureNets import SubNet, TextSubNet

__all__ = ['LMF']


class LMF(nn.Module):
    '''
    Low-rank Multimodal Fusion (LMF).

    Fuses text, audio and video representations by padding each unimodal
    feature with a constant 1 and projecting it through modality-specific
    low-rank factors, avoiding the explicit outer-product tensor of full
    tensor fusion.
    '''
    def __init__(self, args):
        '''
        Args:
            args.feature_dims - a length-3 tuple, contains (text_in, audio_in, video_in)
            args.hidden_dims - a length-3 tuple, hidden dims of the text, audio and video sub-networks
            args.dropouts - a length-4 tuple, contains (audio_dropout, video_dropout, text_dropout, post_fusion_dropout)
            args.rank - int, specifying the rank of the low-rank fusion factors
            args.num_classes, args.train_mode - together determine the output dimension
                (num_classes for classification, 1 otherwise)
        Output:
            (return value in forward) a dict with the unimodal features, the fused feature
            and the prediction tensor 'M' of shape (batch_size, output_dim)
        '''
        super(LMF, self).__init__()

        # dimensions are specified in the order of text, audio and video
        self.text_in, self.audio_in, self.video_in = args.feature_dims
        self.text_hidden, self.audio_hidden, self.video_hidden = args.hidden_dims
        self.text_out = self.text_hidden // 2
        self.output_dim = args.num_classes if args.train_mode == "classification" else 1
        self.rank = args.rank
        self.audio_prob, self.video_prob, self.text_prob, self.post_fusion_prob = args.dropouts

        # define the pre-fusion subnetworks
        self.audio_subnet = SubNet(self.audio_in, self.audio_hidden, self.audio_prob)
        self.video_subnet = SubNet(self.video_in, self.video_hidden, self.video_prob)
        self.text_subnet = TextSubNet(self.text_in, self.text_hidden, self.text_out, dropout=self.text_prob)

        # define the post_fusion layers
        self.post_fusion_dropout = nn.Dropout(p=self.post_fusion_prob)
        # self.post_fusion_layer_1 = nn.Linear((self.text_out + 1) * (self.video_hidden + 1) * (self.audio_hidden + 1), self.post_fusion_dim)
        # modality-specific low-rank factors: one (hidden_dim + 1) x output_dim projection per rank
        self.audio_factor = Parameter(torch.Tensor(self.rank, self.audio_hidden + 1, self.output_dim))
        self.video_factor = Parameter(torch.Tensor(self.rank, self.video_hidden + 1, self.output_dim))
        self.text_factor = Parameter(torch.Tensor(self.rank, self.text_out + 1, self.output_dim))
        self.fusion_weights = Parameter(torch.Tensor(1, self.rank))
        self.fusion_bias = Parameter(torch.Tensor(1, self.output_dim))

        # init the factors
        xavier_normal_(self.audio_factor)
        xavier_normal_(self.video_factor)
        xavier_normal_(self.text_factor)
        xavier_normal_(self.fusion_weights)
        self.fusion_bias.data.fill_(0)
    def forward(self, text_x, audio_x, video_x):
        '''
        Args:
            text_x: tensor of shape (batch_size, sequence_len, text_in)
            audio_x: tensor of shape (batch_size, 1, audio_in)
            video_x: tensor of shape (batch_size, 1, video_in)
        Returns:
            a dict with the unimodal features, the fused feature and the prediction 'M'
        '''
        audio_x = audio_x.squeeze(1)
        video_x = video_x.squeeze(1)

        audio_h = self.audio_subnet(audio_x)
        video_h = self.video_subnet(video_x)
        text_h = self.text_subnet(text_x)
        batch_size = audio_h.data.shape[0]

        # next we perform low-rank multimodal fusion
        # here is a more efficient implementation than the one the paper describes,
        # basically swapping the order of summation and elementwise product
        # first we perform "tensor fusion", which is essentially appending 1s to the tensors and taking the Kronecker product
        add_one = torch.ones(size=[batch_size, 1], requires_grad=False).type_as(audio_h).to(text_x.device)
        _audio_h = torch.cat((add_one, audio_h), dim=1)
        _video_h = torch.cat((add_one, video_h), dim=1)
        _text_h = torch.cat((add_one, text_h), dim=1)

        fusion_audio = torch.matmul(_audio_h, self.audio_factor)
        fusion_video = torch.matmul(_video_h, self.video_factor)
        fusion_text = torch.matmul(_text_h, self.text_factor)
        fusion_zy = fusion_audio * fusion_video * fusion_text

        # output = torch.sum(fusion_zy, dim=0).squeeze()
        # use linear transformation instead of simple summation, more flexibility
        output = torch.matmul(self.fusion_weights, fusion_zy.permute(1, 0, 2)).squeeze() + self.fusion_bias
        output = output.view(-1, self.output_dim)

        res = {
            'Feature_t': text_h,
            'Feature_a': audio_h,
            'Feature_v': video_h,
            'Feature_f': fusion_zy.permute(1, 0, 2).squeeze(),
            'M': output
        }
        return res
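

# A minimal usage sketch, assuming an args namespace that carries the fields read
# in __init__ (feature_dims, hidden_dims, dropouts, rank, num_classes, train_mode)
# and that SubNet/TextSubNet from models.subNets.FeatureNets are importable with
# the signatures used above; the feature sizes below are purely illustrative.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        feature_dims=(768, 33, 709),    # (text_in, audio_in, video_in), illustrative sizes
        hidden_dims=(128, 16, 32),      # (text_hidden, audio_hidden, video_hidden)
        dropouts=(0.2, 0.2, 0.2, 0.2),  # (audio, video, text, post_fusion)
        rank=4,
        num_classes=1,
        train_mode='regression',
    )
    model = LMF(args)

    batch_size, seq_len = 8, 50
    text_x = torch.randn(batch_size, seq_len, 768)
    audio_x = torch.randn(batch_size, 1, 33)
    video_x = torch.randn(batch_size, 1, 709)

    res = model(text_x, audio_x, video_x)
    print(res['M'].shape)  # expected: torch.Size([8, 1])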