-
Notifications
You must be signed in to change notification settings - Fork 150
/
Copy pathrewards.py
56 lines (48 loc) · 2.18 KB
/
rewards.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import sys
def _diversity_reward(seq, pick_idxs, ignore_far_sim, temp_dist_thre):
    """Return the mean pairwise cosine dissimilarity among picked frames [Eq.3].

    Args:
        seq: 2-D feature tensor of shape (seq_len, dim).
        pick_idxs: 1-D long tensor of picked frame indices (at least 2 entries).
        ignore_far_sim (bool): treat temporally distant picked pairs as fully dissimilar.
        temp_dist_thre (int): pairs whose index distance exceeds this are forced to 1.
    """
    num_picks = pick_idxs.numel()
    # clamp avoids 0/0 -> NaN for an all-zero feature row; identical otherwise
    normed_seq = seq / seq.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
    dissim_mat = 1. - torch.matmul(normed_seq, normed_seq.t())  # dissimilarity matrix [Eq.4]
    # advanced indexing returns a copy, so the in-place write below is safe
    dissim_submat = dissim_mat[pick_idxs, :][:, pick_idxs]
    if ignore_far_sim:
        # ignore temporally distant similarity
        pick_mat = pick_idxs.expand(num_picks, num_picks)
        temp_dist_mat = torch.abs(pick_mat - pick_mat.t())
        dissim_submat[temp_dist_mat > temp_dist_thre] = 1.
    return dissim_submat.sum() / (num_picks * (num_picks - 1.))


def _representativeness_reward(seq, pick_idxs):
    """Return exp(-mean distance from every frame to its nearest picked frame) [Eq.5].

    Args:
        seq: 2-D feature tensor of shape (seq_len, dim).
        pick_idxs: 1-D long tensor of picked frame indices (at least 1 entry).
    """
    n = seq.size(0)
    sq_norms = torch.pow(seq, 2).sum(dim=1, keepdim=True).expand(n, n)
    dist_mat = sq_norms + sq_norms.t()
    # |x_i|^2 + |x_j|^2 - 2 x_i.x_j  ->  squared Euclidean distance matrix
    dist_mat.addmm_(seq, seq.t(), beta=1, alpha=-2)
    dist_mat = dist_mat[:, pick_idxs]
    dist_mat = dist_mat.min(1, keepdim=True)[0]
    return torch.exp(-dist_mat.mean())


def compute_reward(seq, actions, ignore_far_sim=True, temp_dist_thre=20, use_gpu=False):
    """
    Compute diversity reward and representativeness reward
    Args:
        seq: sequence of features, shape (1, seq_len, dim); the last axis is
            treated as the feature axis and all leading axes are flattened
            into the frame axis
        actions: binary action sequence, shape (1, seq_len, 1); non-zero
            entries mark selected frames
        ignore_far_sim (bool): whether to ignore temporally distant similarity (default: True)
        temp_dist_thre (int): threshold for ignoring temporally distant similarity (default: 20)
        use_gpu (bool): whether to move the constant zero rewards to the GPU
    Returns:
        0-d tensor: 0.5 * (diversity reward + representativeness reward),
        or 0 when no frame is selected
    """
    _seq = seq.detach()
    _actions = actions.detach()
    # Flatten before/after nonzero() so pick_idxs is always 1-D. A squeeze()
    # here would give a 0-d tensor for a single pick (breaking the column
    # indexing in the representativeness term) and would lose the only pick
    # entirely when seq_len == 1.
    pick_idxs = _actions.reshape(-1).nonzero().reshape(-1)
    num_picks = pick_idxs.numel()
    if num_picks == 0:
        # give zero reward if no frames are selected
        reward = torch.tensor(0.)
        if use_gpu:
            reward = reward.cuda()
        return reward
    # (seq_len, dim); reshape instead of squeeze so seq_len == 1 or dim == 1 keep both axes
    _seq = _seq.reshape(-1, _seq.size(-1))
    # compute diversity reward (undefined for a single pick, so use 0)
    if num_picks == 1:
        reward_div = torch.tensor(0.)
        if use_gpu:
            reward_div = reward_div.cuda()
    else:
        reward_div = _diversity_reward(_seq, pick_idxs, ignore_far_sim, temp_dist_thre)
    # compute representativeness reward
    reward_rep = _representativeness_reward(_seq, pick_idxs)
    # combine the two rewards
    return (reward_div + reward_rep) * 0.5