-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfine_grained_unet.py
106 lines (78 loc) · 3.96 KB
/
fine_grained_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from backboned_unet import get_backbone, unet
from loss import *
from keras.layers import Input, Conv2D
from keras.models import Model
from keras.optimizers import SGD, Adam
####### custom loss #######
def disc_loss(y_true, y_pred, alpha=10):
    """Segmentation loss for the disc channel: ``(1 - dice) + alpha * bce``.

    ``dice_coef`` and ``reweighting_bce`` are not defined in this file;
    presumably they come from ``from loss import *`` -- confirm there.

    Args:
        y_true: ground-truth tensor (``orig_loss`` passes the channel-0 slice).
        y_pred: prediction tensor with the same layout as ``y_true``.
        alpha: weight of the re-weighted BCE term (default 10).

    Returns:
        ``1 - dice_coef(y_true, y_pred) + alpha * reweighting_bce(y_true, y_pred)``.
    """
    # NOTE: a border-dice term (border_dice_coef) was previously computed here
    # but never added to the returned loss (its weight was 0). It is no longer
    # evaluated; add it explicitly if a fine-tuning variant needs it.
    dice = dice_coef(y_true, y_pred)
    bce_ = reweighting_bce(y_true, y_pred)
    return 1 - dice + alpha * bce_
def tuochu_loss(y_true, y_pred, alpha=10, beta=100, pos_boost=0):
    """Loss for the tuochu target: ``(1 - dice) + alpha * bce + beta * focal * scale``.

    ``dice_coef``, ``reweighting_bce``, ``focal_loss``, ``K`` and ``tf`` are
    not defined in this file; presumably they come from ``from loss import *``.

    Args:
        y_true: ground-truth tensor whose first axis is the batch.
        y_pred: prediction tensor with the same layout as ``y_true``.
        alpha: weight of the re-weighted BCE term (default 10).
        beta: weight of the focal term (default 100).
        pos_boost: extra relative weight on the focal term for samples whose
            ground truth sums to > 0. The default 0 disables it, which matches
            the previous hard-coded ``0 *`` factor.

    Returns:
        The combined loss. The focal term is multiplied by a ``(batch, 1)``
        per-sample factor, so the result is broadcast against that shape.
    """
    dice = dice_coef(y_true, y_pred)
    bce_ = reweighting_bce(y_true, y_pred)
    focal = focal_loss(y_true, y_pred)
    # Per-sample flag of shape (batch, 1): True where the flattened ground truth sums to > 0.
    posmask = K.greater(K.sum(K.batch_flatten(y_true), axis=1, keepdims=True), 0)
    # With pos_boost == 0 this factor is all ones. It is still applied so the
    # output shape/broadcasting matches the original formulation exactly.
    focal_scale = 1 + pos_boost * tf.cast(posmask, tf.float32)
    return 1 - dice + alpha * bce_ + beta * focal * focal_scale
def orig_loss(y_true, y_pred):
    """Loss for the full-image ('orig_branch') head.

    Channel 0 is scored with ``disc_loss`` and channel 1 with ``tuochu_loss``.
    The two terms are summed, with the tuochu term scaled by ``tuochu_weight``.
    """
    tuochu_weight = 1
    disc_part = disc_loss(y_true[..., 0:1], y_pred[..., 0:1])
    tuochu_part = tuochu_loss(y_true[..., 1:2], y_pred[..., 1:2])
    return disc_part + tuochu_weight * tuochu_part
def roi_loss(y_true, y_pred):
    """Loss for the ROI ('roi_branch') head: ``tuochu_loss`` on the whole output."""
    loss_value = tuochu_loss(y_true, y_pred)
    return loss_value
####### custom metric #######
def metric_disc_dice(y_true, y_pred):
    """Dice coefficient on channel 0 of the orig_branch output."""
    disc_true = y_true[..., :1]
    disc_pred = y_pred[..., :1]
    return dice_coef(disc_true, disc_pred)
def metric_tuochu_dice(y_true, y_pred):
    """Dice coefficient on channel 1 of the orig_branch output."""
    tuochu_true = y_true[..., 1:2]
    tuochu_pred = y_pred[..., 1:2]
    return dice_coef(tuochu_true, tuochu_pred)
def metric_roi_tuochu_dice(y_true, y_pred):
    """Dice coefficient over the whole roi_branch output (no channel slicing)."""
    score = dice_coef(y_true, y_pred)
    return score
def _image_presence(y_true, y_pred, threshold):
    """Reduce ground truth and predictions to per-image presence flags.

    Args:
        y_true: ground-truth tensor whose first axis is the batch.
        y_pred: prediction scores with the same layout (sigmoid outputs in
            this model).
        threshold: a predicted pixel counts as positive if its score is
            strictly greater than this value.

    Returns:
        ``(y_t, y_p)``: float32 tensors of shape ``(batch,)``. ``y_t`` is 1
        where the flattened ground truth sums to > 0. ``y_p`` is 1 where at
        least one predicted pixel exceeds ``threshold``.
    """
    y_t = K.cast(K.greater(K.sum(K.batch_flatten(y_true), axis=-1), 0), tf.float32)
    # Binarize predictions first. Sigmoid scores are strictly positive, so
    # summing raw scores would flag every image as "predicted positive".
    pred_mask = K.cast(K.greater(y_pred, threshold), tf.float32)
    y_p = K.cast(K.greater(K.sum(K.batch_flatten(pred_mask), axis=-1), 0), tf.float32)
    return y_t, y_p


def metric_tuochu_recall(y_true, y_pred, threshold=0.5):
    """Image-level recall on channel 1 of the orig_branch output.

    This is the fraction of images with a non-empty channel-1 ground truth
    that have at least one channel-1 pixel predicted above ``threshold``.
    A batch with no positive images yields 0 rather than NaN.
    """
    y_t, y_p = _image_presence(y_true[..., 1], y_pred[..., 1], threshold)
    # float32 sums avoid int8 overflow for batches > 127; epsilon avoids 0/0.
    return K.sum(y_t * y_p) / (K.sum(y_t) + K.epsilon())


def metric_tuochu_precision(y_true, y_pred, threshold=0.5):
    """Image-level precision on channel 1 of the orig_branch output.

    This is the fraction of images predicted positive (any channel-1 pixel
    above ``threshold``) whose channel-1 ground truth is non-empty. A batch
    with no predicted-positive images yields 0 rather than NaN.
    """
    y_t, y_p = _image_presence(y_true[..., 1], y_pred[..., 1], threshold)
    return K.sum(y_t * y_p) / (K.sum(y_p) + K.epsilon())


def metric_roi_tuochu_recall(y_true, y_pred, threshold=0.5):
    """Image-level recall over the whole roi_branch output (see metric_tuochu_recall)."""
    y_t, y_p = _image_presence(y_true, y_pred, threshold)
    return K.sum(y_t * y_p) / (K.sum(y_t) + K.epsilon())


def metric_roi_tuochu_precision(y_true, y_pred, threshold=0.5):
    """Image-level precision over the whole roi_branch output (see metric_tuochu_precision)."""
    y_t, y_p = _image_presence(y_true, y_pred, threshold)
    return K.sum(y_t * y_p) / (K.sum(y_p) + K.epsilon())
####### custom model #######
def fine_grained_unet(backbone_name='resnet50', input_shape=(256, 256, 1), output_channels=(2, 1), stage=5):
    """Build and compile a two-input U-Net whose two heads share one backbone.

    A single-output ``unet`` is built, and the output of its second-to-last
    layer is wrapped as a feature-extractor ``Model``. The full image and the
    ROI input both pass through that same model instance, so the branches
    share backbone weights. Each branch then gets its own 3x3 sigmoid Conv2D
    head.

    Args:
        backbone_name: encoder name forwarded to ``unet``.
        input_shape: shape of each of the two inputs.
        output_channels: ``(orig_channels, roi_channels)``. ``orig_loss`` and
            the orig_branch metrics read channels 0 and 1, so the first entry
            should be at least 2. Any indexable pair is accepted (a list still
            works).
        stage: forwarded to ``unet``.

    Returns:
        A compiled Keras ``Model`` with inputs ``[full_input, roi_input]`` and
        outputs ``[orig_branch, roi_branch]``.
    """
    unet_model = unet(backbone_name, input_shape, 1, stage)
    # Drop the final layer (presumably unet's 1-channel output head) and reuse
    # the preceding feature map as a shared feature extractor.
    backbone = Model(unet_model.input, unet_model.get_layer(index=-2).output)

    full_input = Input(input_shape)
    roi_input = Input(input_shape)
    # The same backbone instance is applied twice -> shared weights.
    x1 = backbone(full_input)
    x2 = backbone(roi_input)

    # orig task branch
    full_output = Conv2D(output_channels[0], kernel_size=3, padding='same', activation='sigmoid',
                         use_bias=True, kernel_initializer='glorot_uniform', name='orig_branch')(x1)
    # fine-grained branch
    roi_output = Conv2D(output_channels[1], kernel_size=3, padding='same', activation='sigmoid',
                        use_bias=True, kernel_initializer='glorot_uniform', name='roi_branch')(x2)

    model = Model(inputs=[full_input, roi_input], outputs=[full_output, roi_output])

    # NOTE(review): `lr` / `decay` are legacy Keras optimizer arguments; newer
    # Keras versions renamed or removed them -- confirm against the installed version.
    sgd = SGD(lr=1e-4, momentum=0.97, decay=1e-6, nesterov=True)
    metriclst = {'orig_branch': [metric_disc_dice, metric_tuochu_dice, metric_tuochu_recall, metric_tuochu_precision],
                 'roi_branch': [metric_roi_tuochu_dice, metric_roi_tuochu_recall, metric_roi_tuochu_precision]}
    model.compile(sgd, loss={'orig_branch': orig_loss, 'roi_branch': roi_loss},
                  loss_weights=[1., 1.], metrics=metriclst)
    return model
if __name__ == '__main__':
    # Smoke test: build and compile the model with a darknet52 backbone.
    model = fine_grained_unet(
        backbone_name='darknet52',
        input_shape=(256, 256, 1),
        output_channels=[2, 1],
        stage=5,
    )
    # Uncomment to print the layer table:
    # model.summary()