From 03b027a2924f6a0fdc717a80220c56d5ab26bccb Mon Sep 17 00:00:00 2001 From: tonghe90 Date: Fri, 1 Oct 2021 17:45:48 +0930 Subject: [PATCH] ad dice loss --- model/pointgroup/pointgroup.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/model/pointgroup/pointgroup.py b/model/pointgroup/pointgroup.py index accf2a5..99a53dc 100644 --- a/model/pointgroup/pointgroup.py +++ b/model/pointgroup/pointgroup.py @@ -12,6 +12,15 @@ import numpy as np from model.transformer import TransformerEncoder +def dice_coefficient(x, target, weight): + eps = 1e-5 + # n_inst = x.size(0) + # x = x.reshape(n_inst, -1) + # target = target.reshape(n_inst, -1) + intersection = (x * target).sum() + union = (x ** 2.0).sum() + (target ** 2.0).sum() + eps + loss = 1. - (2 * intersection / union) + return loss class ResidualBlock(SparseModule): def __init__(self, in_channels, out_channels, norm_fn, indice_key=None): @@ -831,6 +840,9 @@ def loss_fn(loss_inp, epoch): loss_out['score_loss'] = (score_loss, proposals_offset_shift.size(0)-1) loss += (cfg.loss_weight[3] * score_loss) + dice_loss = dice_coefficient(torch.sigmoid(mask_logits.view(-1)), inst_gt_mask.view(-1), weights.view(-1)) + loss_out['dice_loss'] = (dice_loss, dice_loss.new_tensor(1.0)) + return loss, loss_out, infos