Skip to content

Commit

Permalink
ad dice loss
Browse files Browse the repository at this point in the history
  • Loading branch information
tonghe90 committed Oct 1, 2021
1 parent f93b31a commit 03b027a
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions model/pointgroup/pointgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
import numpy as np
from model.transformer import TransformerEncoder

def dice_coefficient(x, target, weight):
eps = 1e-5
# n_inst = x.size(0)
# x = x.reshape(n_inst, -1)
# target = target.reshape(n_inst, -1)
intersection = (x * target).sum()
union = (x ** 2.0).sum() + (target ** 2.0).sum() + eps
loss = 1. - (2 * intersection / union)
return loss

class ResidualBlock(SparseModule):
def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
Expand Down Expand Up @@ -831,6 +840,9 @@ def loss_fn(loss_inp, epoch):
loss_out['score_loss'] = (score_loss, proposals_offset_shift.size(0)-1)
loss += (cfg.loss_weight[3] * score_loss)

dice_loss = dice_coefficient(torch.sigmoid(mask_logits.view(-1)), inst_gt_mask.view(-1), weights.view(-1))
loss_out['dice_loss'] = (dice_loss, dice_loss.new_tensor(1.0))



return loss, loss_out, infos
Expand Down

0 comments on commit 03b027a

Please sign in to comment.