-
Notifications
You must be signed in to change notification settings - Fork 14
/
loss.py
43 lines (36 loc) · 1.1 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
import torch.nn.functional as F
class lossAV(nn.Module):
    """Two-class linear head with a BCE loss on the class-1 softmax probability.

    A ``Linear(128, 2)`` layer maps every feature vector to two logits.
    Called without labels, ``forward`` returns the raw class-1 logits as a
    NumPy array; called with labels, it returns the loss together with
    prediction statistics.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.BCELoss()
        self.FC = nn.Linear(128, 2)

    def forward(self, x, labels=None, r=1):
        """Score ``x``, or compute the loss when ``labels`` is given.

        Args:
            x: Feature tensor whose last dimension is 128. A size-1 dimension
                at index 1 is squeezed away before the linear layer.
                NOTE(review): presumably shaped (N, 1, 128) or (N, 128) --
                confirm against callers.
            labels: Optional target tensor with one entry per row of ``x``.
                It is cast to float for the BCE loss, so values are
                presumably 0/1 -- verify against callers. ``None`` selects
                the scoring-only path.
            r: Divisor applied to the logits before the softmax that feeds
                the loss (acts like a softmax temperature). It does not
                affect the returned scores or predicted labels.

        Returns:
            If ``labels`` is ``None``: a flat NumPy array holding the raw
            (un-normalised) class-1 logit of each row, detached and moved to
            the CPU.

            Otherwise a tuple ``(nloss, predScore, predLabel, correctNum)``:
            the BCE loss, the softmax probabilities for both classes, the
            rounded class-1 probability per row, and the number of rows whose
            predicted label equals ``labels`` (as a float tensor).
        """
        logits = self.FC(x.squeeze(1))

        # Identity check: ``== None`` on a tensor only works via a fallback
        # path and would break for array-like labels.
        if labels is None:
            # Scoring path returns logits, not probabilities.
            return logits[:, 1].view(-1).detach().cpu().numpy()

        # The loss uses the r-scaled logits; reported scores below do not.
        scaled_prob = F.softmax(logits / r, dim=-1)[:, 1]
        nloss = self.criterion(scaled_prob, labels.float())

        # Compute the unscaled softmax once and reuse it for the labels.
        predScore = F.softmax(logits, dim=-1)
        predLabel = torch.round(predScore)[:, 1]
        correctNum = (predLabel == labels).sum().float()
        return nloss, predScore, predLabel, correctNum
class lossV(nn.Module):
    """Two-class linear head that returns only a BCE loss.

    Features go through a ``Linear(128, 2)`` layer; the softmax probability
    of class 1 is compared against the labels with ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.BCELoss()
        self.FC = nn.Linear(128, 2)

    def forward(self, x, labels, r = 1):
        """Return the BCE loss of the class-1 probability against ``labels``.

        Args:
            x: Feature tensor whose last dimension is 128; a size-1 dimension
                at index 1 is squeezed away first.
                NOTE(review): presumably (N, 1, 128) or (N, 128) -- confirm.
            labels: Target tensor, cast to float before the loss.
            r: Divisor applied to the logits before the softmax.

        Returns:
            The loss tensor produced by ``self.criterion``.
        """
        logits = self.FC(x.squeeze(1))
        # Scale the logits first, then keep only the class-1 probability.
        positive_prob = F.softmax(logits / r, dim=-1)[:, 1]
        targets = labels.float()
        return self.criterion(positive_prob, targets)