# forked from minerva-ml/open-solution-data-science-bowl-2018
# models.py — 81 lines (67 loc), 3.96 KB
import numpy as np
import torch.optim as optim
from steps.pytorch.architectures.unet import UNet, UNetMultitask
from steps.pytorch.callbacks import CallbackList, TrainingMonitor, ValidationMonitor, ModelCheckpoint, \
ExperimentTiming, ExponentialLRScheduler, EarlyStopping
from steps.pytorch.models import Model
from steps.pytorch.validation import segmentation_loss
from utils import sigmoid
from callbacks import NeptuneMonitorSegmentation
class PyTorchUNet(Model):
    """Single-output UNet wrapper around the generic steps ``Model``.

    Builds the ``UNet`` network from ``architecture_config['model_params']``,
    an Adam optimizer whose parameter groups come from
    ``weight_regularization_unet``, a single ``'mask'`` segmentation loss with
    weight 1.0, and the standard UNet callback list.
    """

    def __init__(self, architecture_config, training_config, callbacks_config):
        super().__init__(architecture_config, training_config, callbacks_config)
        self.model = UNet(**architecture_config['model_params'])
        self.weight_regularization = weight_regularization_unet
        param_groups = self.weight_regularization(
            self.model, **architecture_config['regularizer_params'])
        self.optimizer = optim.Adam(param_groups,
                                    **architecture_config['optimizer_params'])
        self.loss_function = [('mask', segmentation_loss, 1.0)]
        self.callbacks = callbacks_unet(self.callbacks_config)

    def transform(self, datagen, validation_datagen=None):
        """Run inference and convert each raw output to sigmoid probabilities.

        Every prediction is squeezed, passed through ``sigmoid`` and the
        per-output list is stacked into a single numpy array.
        """
        outputs = self._transform(datagen, validation_datagen)
        for name in outputs:
            activated = [sigmoid(np.squeeze(mask)) for mask in outputs[name]]
            outputs[name] = np.array(activated)
        return outputs
class PyTorchUNetMultitask(Model):
    """Multi-head UNet wrapper predicting mask, contour and center maps.

    Builds the ``UNetMultitask`` network, an Adam optimizer whose parameter
    groups come from ``weight_regularization_unet``, a weighted sum of three
    segmentation losses (mask 0.4, contour 0.5, center 0.1), and the standard
    UNet callback list.
    """

    def __init__(self, architecture_config, training_config, callbacks_config):
        super().__init__(architecture_config, training_config, callbacks_config)
        self.model = UNetMultitask(**architecture_config['model_params'])
        self.weight_regularization = weight_regularization_unet
        param_groups = self.weight_regularization(
            self.model, **architecture_config['regularizer_params'])
        self.optimizer = optim.Adam(param_groups,
                                    **architecture_config['optimizer_params'])
        self.loss_function = [('mask', segmentation_loss, 0.4),
                              ('contour', segmentation_loss, 0.5),
                              ('center', segmentation_loss, 0.1)]
        self.callbacks = callbacks_unet(self.callbacks_config)

    def transform(self, datagen, validation_datagen=None):
        """Run inference and convert each raw output to sigmoid probabilities.

        Every prediction is squeezed, passed through ``sigmoid`` and the
        per-output list is stacked into a single numpy array.
        """
        outputs = self._transform(datagen, validation_datagen)
        for name in outputs:
            activated = [sigmoid(np.squeeze(mask)) for mask in outputs[name]]
            outputs[name] = np.array(activated)
        return outputs
def weight_regularization(model, regularize, weight_decay_conv2d, weight_decay_linear):
    """Build optimizer parameter groups for a model with features/classifier parts.

    When ``regularize`` is truthy, the convolutional (``model.features``) and
    linear (``model.classifier``) parameters get separate weight-decay values;
    otherwise all parameters are returned as a single unregularized entry.
    """
    if not regularize:
        return [model.parameters()]
    return [
        {'params': model.features.parameters(), 'weight_decay': weight_decay_conv2d},
        {'params': model.classifier.parameters(), 'weight_decay': weight_decay_linear},
    ]
def weight_regularization_unet(model, regularize, weight_decay_conv2d):
    """Build optimizer parameter groups for a UNet model.

    When ``regularize`` is truthy, all model parameters are wrapped in a
    single group carrying ``weight_decay_conv2d``; otherwise they are returned
    as one unregularized entry.
    """
    if regularize:
        return [{'params': model.parameters(), 'weight_decay': weight_decay_conv2d}]
    return [model.parameters()]
def callbacks_unet(callbacks_config):
    """Assemble the standard training callback list for the UNet models.

    Each callback is constructed from its own sub-dict of
    ``callbacks_config`` and the full set is returned as a ``CallbackList``
    in the fixed order: timing, training/validation monitors, checkpointing,
    LR scheduling, Neptune logging, early stopping.
    """
    timing = ExperimentTiming(**callbacks_config['experiment_timing'])
    checkpoint = ModelCheckpoint(**callbacks_config['model_checkpoint'])
    scheduler = ExponentialLRScheduler(**callbacks_config['lr_scheduler'])
    train_monitor = TrainingMonitor(**callbacks_config['training_monitor'])
    valid_monitor = ValidationMonitor(**callbacks_config['validation_monitor'])
    neptune = NeptuneMonitorSegmentation(**callbacks_config['neptune_monitor'])
    stopper = EarlyStopping(**callbacks_config['early_stopping'])
    ordered = [timing, train_monitor, valid_monitor,
               checkpoint, scheduler, neptune, stopper]
    return CallbackList(callbacks=ordered)