mta.py
import torch
import torch.nn.functional as F


def gaussian_kernel(mu, bandwidth, datapoints):
    # Gaussian kernel density of each datapoint around the current mode mu
    dist = torch.norm(datapoints - mu, dim=-1, p=2)
    density = torch.exp(-dist**2 / (2 * bandwidth**2))
    return density


def solve_mta(model, inputs, args):
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            image_features, text_features, logit_scale = model(inputs, features=True)

    logits = image_features @ text_features.t() * logit_scale

    lambda_y = args.lambda_y
    lambda_q = args.lambda_q
    max_iter = 5
    temperature = 1
    batch_size = image_features.shape[0]

    # Bandwidth: per-point scale from the mean squared distance to the 30% nearest neighbors
    dist = torch.cdist(image_features, image_features)
    sorted_dist, _ = torch.sort(dist, dim=1)
    k = int(0.3 * (image_features.shape[0] - 1))
    selected_distances = sorted_dist[:, 1:k + 1]**2  # exclude the distance to the point itself
    mean_distance = torch.mean(selected_distances, dim=1)
    bandwidth = torch.sqrt(0.5 * mean_distance)

    # Affinity matrix based on logits
    affinity_matrix = (logits / temperature).softmax(1) @ (logits / temperature).softmax(1).t()

    # Inlierness scores initialization: uniform
    y = torch.ones(batch_size, device=image_features.device) / batch_size

    # Mode initialization: original image embedding
    mode_init = image_features[0]
    mode = mode_init

    convergence = False
    th = 1e-6
    iter = 0

    while not convergence:

        ###################
        # Inlierness step #
        ###################
        density = gaussian_kernel(mode, bandwidth, image_features)

        convergence_inlierness = False
        i = 0
        while not convergence_inlierness:
            i += 1
            old_y = y
            weighted_affinity = affinity_matrix * y.unsqueeze(0)
            y = F.softmax(1 / lambda_y * (density + lambda_q * torch.sum(weighted_affinity, dim=1)), dim=-1)
            if torch.norm(old_y - y) < th or i >= max_iter:
                convergence_inlierness = True

        #############
        # Mode step #
        #############
        convergence_mode = False
        i = 0
        while not convergence_mode:
            i += 1
            old_mode = mode
            density = gaussian_kernel(mode, bandwidth, image_features)
            weighted_density = density * y
            mode = torch.sum(weighted_density.unsqueeze(1) * image_features, dim=0) / torch.sum(weighted_density)
            mode /= mode.norm(p=2, dim=-1)
            if torch.norm(old_mode - mode) < th or i >= max_iter:
                convergence_mode = True

        iter += 1
        if iter >= max_iter:
            convergence = True

    output = mode.unsqueeze(0) @ text_features.t() * logit_scale
    return output
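

# A minimal, self-contained usage sketch (illustration only, not part of the original
# file). It assumes `model` is a CLIP-style wrapper that, when called with
# features=True, returns (image_features, text_features, logit_scale) for a batch of
# augmented views of one image, and that `args` carries the lambda_y / lambda_q
# hyperparameters; the dummy encoder and the hyperparameter values below are
# placeholders, not the official setup.
if __name__ == "__main__":
    from argparse import Namespace

    class _DummyModel:
        def __call__(self, inputs, features=False):
            # Ignores the pixel inputs and returns random L2-normalized embeddings,
            # mimicking the (image_features, text_features, logit_scale) interface.
            n_views, n_classes, dim = inputs.shape[0], 10, 512
            image_features = F.normalize(torch.randn(n_views, dim), dim=-1)
            text_features = F.normalize(torch.randn(n_classes, dim), dim=-1)
            return image_features, text_features, torch.tensor(100.0)

    args = Namespace(lambda_y=0.2, lambda_q=4.0)  # placeholder hyperparameter values
    views = torch.randn(64, 3, 224, 224)          # stands in for augmented views of one image
    scores = solve_mta(_DummyModel(), views, args)
    print(scores.shape)  # torch.Size([1, 10]) -> class scores for the single test image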