-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
44 lines (38 loc) · 1.5 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
class SoftIoULoss(nn.Module):
    """Soft (differentiable) IoU loss.

    For a prediction ``P`` and target ``G`` the loss is::

        1 - (sum(P * G) + smooth) / (sum(P) + sum(G) - sum(P * G) + smooth)

    NOTE: ``.sum()`` is called without a ``dim`` argument, so every element of
    the tensors (including the batch dimension) is pooled into a single IoU;
    this is a batch-level IoU, not a mean of per-sample IoUs.

    Args:
        smooth: Additive smoothing term applied to numerator and denominator;
            avoids division by zero when both prediction and target are empty.
            Defaults to 1 (the value previously hard-coded).
    """

    def __init__(self, smooth=1):
        super(SoftIoULoss, self).__init__()
        self.smooth = smooth

    def _single_loss(self, pred, gt_masks):
        """Return ``1 - soft IoU`` of one prediction tensor against ``gt_masks``."""
        intersection = (pred * gt_masks).sum()
        union = pred.sum() + gt_masks.sum() - intersection
        return 1 - (intersection + self.smooth) / (union + self.smooth)

    def forward(self, preds, gt_masks):
        """Compute the soft IoU loss.

        Args:
            preds: A tensor, or a list/tuple of tensors (e.g. deep-supervision
                outputs). Presumably each holds probabilities in [0, 1] and is
                broadcast-compatible with ``gt_masks`` -- TODO confirm with callers.
            gt_masks: Target mask tensor.

        Returns:
            0-dim tensor. For a list/tuple, the arithmetic mean of the
            per-prediction losses.

        Raises:
            ValueError: If ``preds`` is an empty list/tuple.
        """
        if isinstance(preds, (list, tuple)):
            if not preds:
                raise ValueError("SoftIoULoss received an empty list/tuple of predictions")
            # sum() starts from 0 and adds in order, matching the original accumulation.
            total = sum(self._single_loss(pred, gt_masks) for pred in preds)
            return total / len(preds)
        return self._single_loss(preds, gt_masks)
class ISNetLoss(nn.Module):
    """Combined mask + edge loss for a model returning ``(mask_pred, edge_pred)``.

    total = SoftIoU(preds[0], gt_masks)
            + edge_weight * BCE(preds[1], edge_gt)
            + SoftIoU(sigmoid(preds[1]), edge_gt)

    where ``edge_gt`` is produced from ``gt_masks`` by ``Get_gradient_nopadding``
    (imported from ``utils``; presumably a fixed gradient/edge operator --
    TODO confirm against utils).

    Args:
        edge_weight: Multiplier on the edge BCE term. Defaults to 10 (the
            value previously hard-coded).
    """

    def __init__(self, edge_weight=10):
        super(ISNetLoss, self).__init__()
        # Attribute names kept unchanged: submodules registered here appear in
        # state_dict under these names.
        self.softiou = SoftIoULoss()
        self.bce = nn.BCELoss()
        self.grad = Get_gradient_nopadding()
        self.edge_weight = edge_weight

    def forward(self, preds, gt_masks):
        """Compute the total loss.

        Args:
            preds: Indexable with at least two entries: ``preds[0]`` is the
                mask prediction, ``preds[1]`` the edge prediction.
            gt_masks: Ground-truth mask tensor.

        Returns:
            Scalar loss tensor (mask loss + edge loss).
        """
        # Clone before handing the mask to the edge operator (presumably a
        # defensive copy so the caller's tensor is never aliased/modified).
        edge_gt = self.grad(gt_masks.clone())

        # Mask loss.
        loss_img = self.softiou(preds[0], gt_masks)

        # Edge loss.
        # NOTE(review): nn.BCELoss requires inputs already in [0, 1], yet the
        # same preds[1] is also passed through .sigmoid() for the IoU term. If
        # preds[1] is a probability this squashes it into ~[0.5, 0.73]; if it
        # is a logit, BCELoss is the wrong criterion (BCEWithLogitsLoss).
        # Behavior preserved here -- confirm against the model's edge output.
        # Likewise BCELoss targets should lie in [0, 1]; verify edge_gt's range.
        edge_pred = preds[1]
        loss_edge = (self.edge_weight * self.bce(edge_pred, edge_gt)
                     + self.softiou(edge_pred.sigmoid(), edge_gt))

        return loss_img + loss_edge