-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathmodels.py
223 lines (183 loc) · 7.79 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#-------------------------------------
# Project: Transductive Propagation Network for Few-shot Learning
# Date: 2019.1.11
# Author: Yanbin Liu
# All Rights Reserved
#-------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
class CNNEncoder(nn.Module):
    """Convolutional feature extractor built from four equally shaped stages.

    Every stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)`` with 64 output channels. The first stage takes 3 input
    channels, the remaining three take 64.

    Args:
        args: configuration mapping, kept on ``self.args``. The ``'h_dim'``
            and ``'z_dim'`` entries are looked up but do not influence the
            layers, whose widths are fixed at 64.
    """

    def __init__(self, args):
        super(CNNEncoder, self).__init__()
        self.args = args
        # Looked up but otherwise unused; a mapping without these keys still
        # fails here with KeyError.
        h_dim = args['h_dim']
        z_dim = args['z_dim']

        def conv_stage(in_channels):
            # One conv -> batch-norm -> ReLU -> 2x2 max-pool stage.
            return nn.Sequential(
                nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        # Created and registered one after another so that sub-module names
        # (layer1..layer4) and the order of weight initialisation are fixed.
        self.layer1 = conv_stage(3)
        self.layer2 = conv_stage(64)
        self.layer3 = conv_stage(64)
        self.layer4 = conv_stage(64)

    def forward(self, x):
        """Apply the four stages in sequence.

        x: bs*3*84*84. Each stage halves the spatial size (flooring), so an
        84x84 input leaves the encoder as bs*64*5*5.
        """
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x
class RelationNetwork(nn.Module):
    """Graph construction module: maps each embedding to a single scalar.

    Two conv stages (each followed by a padded 2x2 max-pool) reduce a
    64x5x5 feature map to 1x2x2, which two linear layers turn into one
    value per example.
    """

    def __init__(self):
        super(RelationNetwork, self).__init__()
        # 64x5x5 -> 64x3x3 (padded pooling)
        self.layer1 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, padding=1),
        )
        # 64x3x3 -> 1x2x2 (padded pooling)
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, padding=1),
        )
        # Head operating on the flattened 1x2x2 map.
        self.fc3 = nn.Linear(2 * 2, 8)
        self.fc4 = nn.Linear(8, 1)

        # Stand-alone pooling layers; not used by forward() but kept as
        # registered attributes.
        self.m0 = nn.MaxPool2d(2)  # max-pool without padding
        self.m1 = nn.MaxPool2d(2, padding=1)  # max-pool with padding

    def forward(self, x, rn):
        """Return one scalar per example, shaped bs*1.

        x: anything that can be viewed as bs*64*5*5.
        rn: accepted for interface compatibility; not used here.
        """
        feature_maps = x.view(-1, 64, 5, 5)
        hidden = self.layer2(self.layer1(feature_maps))

        # flatten
        batch = hidden.size(0)
        hidden = F.relu(self.fc3(hidden.view(batch, -1)))

        scores = self.fc4(hidden)  # no relu
        return scores.view(scores.size(0), -1)  # bs*1
class Prototypical(nn.Module):
    """Main module for prototypical networks.

    A class prototype is the mean of that class's support embeddings. Each
    query is scored against every prototype by the negative mean squared
    difference, and those scores are used as logits for cross-entropy.

    Args:
        args: configuration mapping. ``'x_dim'`` is a comma-separated string
            of three integers unpacked as width, height, channels;
            ``'h_dim'`` and ``'z_dim'`` are stored as attributes. The whole
            mapping is also handed to ``CNNEncoder``.
    """

    def __init__(self, args):
        super(Prototypical, self).__init__()
        self.im_width, self.im_height, self.channels = list(map(int, args['x_dim'].split(',')))
        self.h_dim, self.z_dim = args['h_dim'], args['z_dim']
        self.args = args
        self.encoder = CNNEncoder(args)

    def forward(self, inputs):
        """Compute the episode loss and query accuracy.

        inputs are preprocessed, given as [support, s_labels, query, q_labels]:
            support:  (N_way*N_shot)x3x84x84
            query:    (N_way*N_query)x3x84x84
            s_labels: (N_way*N_shot)xN_way, one-hot
            q_labels: (N_way*N_query)xN_way, one-hot

        Support rows must be grouped by class in label-index order: rows
        ``c*N_shot .. (c+1)*N_shot - 1`` are averaged into prototype ``c``.

        Returns:
            (loss, acc): mean cross-entropy over the queries and the
            fraction of correctly classified queries, both as tensors.
        """
        support, s_labels, query, q_labels = inputs
        num_classes = s_labels.shape[1]
        num_support = s_labels.shape[0] // num_classes
        num_queries = query.shape[0] // num_classes

        # Embed support and query together in one encoder pass.
        emb = self.encoder(torch.cat((support, query), 0))
        # Flatten per example; the feature size is inferred rather than
        # fixed at 64*5*5, so other encoder output sizes work as well.
        emb = emb.reshape(emb.size(0), -1)
        emb_s, emb_q = torch.split(
            emb, [num_classes * num_support, num_classes * num_queries], 0)

        prototypes = emb_s.reshape(num_classes, num_support, -1).mean(1)  # CxD
        # (Qx1xD - 1xCxD)^2 averaged over D -> QxC
        dist = ((emb_q.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).mean(2)

        logits = -dist
        gt = torch.argmax(q_labels, 1)
        # Functional form: no loss module to place on a particular device.
        loss = F.cross_entropy(logits, gt)

        ## acc
        pred = torch.argmax(logits, 1)
        correct = (pred == gt).sum()
        total = num_queries * num_classes
        acc = 1.0 * correct.float() / float(total)

        return loss, acc
class LabelPropagation(nn.Module):
    """Label propagation over a graph built from encoder embeddings.

    Embeddings of all support and query examples are scaled per example by
    a value predicted by ``RelationNetwork``, turned into a Gaussian
    similarity graph, optionally sparsified to a k-nearest-neighbour graph,
    symmetrically normalised, and used to propagate the support labels to
    the queries in closed form.

    Args:
        args: configuration mapping. Keys read here:
            ``'x_dim'``  comma-separated "width,height,channels" string;
            ``'h_dim'``, ``'z_dim'``  stored as attributes;
            ``'rn'``     300 -> learned sigma, fixed alpha;
                         30  -> learned sigma, learned alpha;
            ``'alpha'``  initial / fixed propagation weight;
            ``'k'``      neighbours kept per node (<= 0 keeps the full graph).
            For any other ``'rn'`` value neither ``alpha`` nor ``sigma`` is
            set by this class.
    """

    def __init__(self, args):
        super(LabelPropagation, self).__init__()
        self.im_width, self.im_height, self.channels = list(map(int, args['x_dim'].split(',')))
        self.h_dim, self.z_dim = args['h_dim'], args['z_dim']

        self.args = args
        self.encoder = CNNEncoder(args)
        self.relation = RelationNetwork()

        # alpha is created without a hard-coded device. The learned variant
        # is a Parameter and moves with the module; the fixed variant is a
        # plain tensor that forward() moves to the working device.
        if args['rn'] == 300:   # learned sigma, fixed alpha
            self.alpha = torch.tensor([args['alpha']], requires_grad=False)
        elif args['rn'] == 30:  # learned sigma, learned alpha
            self.alpha = nn.Parameter(torch.tensor([args['alpha']]), requires_grad=True)

    def forward(self, inputs):
        """Compute the episode loss and query accuracy.

        inputs are preprocessed, given as [support, s_labels, query, q_labels]:
            support:  (N_way*N_shot)x3x84x84
            query:    (N_way*N_query)x3x84x84
            s_labels: (N_way*N_shot)xN_way, one-hot
            q_labels: (N_way*N_query)xN_way, one-hot

        Side effect: when ``args['rn']`` is 30 or 300 the per-example scale
        predicted by the relation network is stored on ``self.sigma``.

        Returns:
            (loss, acc): cross-entropy of the propagated scores over support
            and query rows together, and accuracy over the query rows only.
        """
        # init
        eps = float(np.finfo(float).eps)
        support, s_labels, query, q_labels = inputs
        num_classes = s_labels.shape[1]
        num_support = s_labels.shape[0] // num_classes
        num_queries = query.shape[0] // num_classes

        # Step1: Embedding
        emb_all = self.encoder(torch.cat((support, query), 0))
        emb_all = emb_all.reshape(emb_all.size(0), -1)  # N*d, d inferred
        N = emb_all.shape[0]

        # Step2: Graph Construction
        ## sigma
        if self.args['rn'] in [30, 300]:
            self.sigma = self.relation(emb_all, self.args['rn'])

        ## W
        emb_all = emb_all / (self.sigma + eps)  # N*d
        emb1 = torch.unsqueeze(emb_all, 1)  # N*1*d
        emb2 = torch.unsqueeze(emb_all, 0)  # 1*N*d
        W = ((emb1 - emb2) ** 2).mean(2)  # N*N*d -> N*N
        W = torch.exp(-W / 2)

        ## keep top-k values
        # Clamp k to N: torch.topk raises when asked for more entries than a
        # row has. With k == N the mask is all ones, i.e. the full graph.
        k = min(self.args['k'], N)
        if k > 0:
            _, indices = torch.topk(W, k)
            mask = torch.zeros_like(W).scatter(1, indices, 1)
            # union, kNN graph: keep an edge if either endpoint selected it
            mask = ((mask + torch.t(mask)) > 0).to(W.dtype)
            W = W * mask

        ## normalize: S = D^{-1/2} W D^{-1/2}
        D = W.sum(0)
        D_sqrt_inv = torch.sqrt(1.0 / (D + eps))
        S = D_sqrt_inv.unsqueeze(1) * W * D_sqrt_inv.unsqueeze(0)

        # Step3: Label Propagation, F = (I-\alpha S)^{-1}Y
        ys = s_labels
        # Unlabelled (query) rows start as zeros, on the graph's device.
        yu = torch.zeros(num_classes * num_queries, num_classes,
                         device=S.device, dtype=S.dtype)
        y = torch.cat((ys, yu), 0)
        alpha = self.alpha.to(S.device)
        identity = torch.eye(N, device=S.device, dtype=S.dtype)
        # NOTE: eps is added element-wise to the whole matrix (not only the
        # diagonal); kept as-is so numerical results do not change.
        # The result is named `scores` so the module-level alias F stays visible.
        scores = torch.matmul(torch.inverse(identity - alpha * S + eps), y)
        scores_q = scores[num_classes * num_support:, :]  # query predictions

        # Step4: Cross-Entropy Loss
        ## both support and query loss
        gt = torch.argmax(torch.cat((s_labels, q_labels), 0), 1)
        loss = F.cross_entropy(scores, gt)

        ## acc
        predq = torch.argmax(scores_q, 1)
        gtq = torch.argmax(q_labels, 1)
        correct = (predq == gtq).sum()
        total = num_queries * num_classes
        acc = 1.0 * correct.float() / float(total)

        return loss, acc