forked from HKervadec/ai4mi_project
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
70 lines (59 loc) · 2.57 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import segmentation_models_pytorch as smp
from torch import nn
from pathlib import Path
import torch
class SegmentationModelBase(nn.Module):
    """Mixin adding encoder freezing and checkpoint loading to smp-style models.

    Subclasses call ``initialize_base`` before building the underlying model
    (the stored encoder settings are passed to its constructor) and call
    ``freeze_encoder_layers`` once ``self.encoder`` exists.
    """

    def initialize_base(self, kwargs):
        """Store the encoder configuration read from ``kwargs``.

        Args:
            kwargs: mapping that must contain ``encoder_name`` and
                ``unfreeze_enc_last_n_layers``; other keys are ignored.

        Raises:
            KeyError: if either required key is missing.
        """
        self.encoder_name = kwargs['encoder_name']
        # Always the 'imagenet' preset; subclasses forward it to the smp constructor.
        self.encoder_weights = 'imagenet'
        # How many of the last encoder layers (top-level children) stay trainable.
        self.unfreeze_enc_last_n_layers = kwargs['unfreeze_enc_last_n_layers']

    def freeze_encoder_layers(self):
        """Freeze every top-level child of ``self.encoder`` except the last n.

        n is ``self.unfreeze_enc_last_n_layers``. The parameters of the frozen
        children get ``requires_grad = False``; the remaining children are not
        modified. The frozen count is clamped to ``[0, number of children]``,
        so n larger than the encoder freezes nothing and a negative n freezes
        everything. The logged count is always in range.
        """
        enc_layers = list(self.encoder.children())
        enc_num_layers = len(enc_layers)
        freeze_first_n_layers = min(
            enc_num_layers,
            max(0, enc_num_layers - self.unfreeze_enc_last_n_layers),
        )
        for layer in enc_layers[:freeze_first_n_layers]:
            for param in layer.parameters():
                param.requires_grad = False
        print(f"> Initialized encoder {self.encoder_name} with first {freeze_first_n_layers}/{enc_num_layers} layers frozen")

    def init_weights(self, args):
        """Load trained weights from ``<args.dest>/bestweights.pt`` in evaluation mode.

        Args:
            args: namespace providing ``evaluation`` (bool), ``dest`` (the
                directory holding ``bestweights.pt``, as ``str`` or ``Path``)
                and ``gpu`` (bool).

        Does nothing when ``args.evaluation`` is false.
        """
        if not args.evaluation:
            return
        print(f"Loading model weights from {args.dest} ...")
        # Path() lets callers pass dest as a plain string as well as a Path.
        trained_weights_path = Path(args.dest) / "bestweights.pt"
        # Without a GPU, remap stored tensors onto the CPU; map_location=None
        # keeps the devices recorded in the checkpoint.
        device = 'cpu' if not args.gpu else None
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(trained_weights_path, map_location=device)
        self.load_state_dict(state_dict)
class UNet(smp.Unet, SegmentationModelBase):
    """``smp.Unet`` whose encoder is partially frozen after construction.

    Args:
        in_channels: number of input channels, forwarded to ``smp.Unet``.
        out_channels: number of output classes, forwarded as ``classes``.
        **kwargs: must provide ``encoder_name`` and
            ``unfreeze_enc_last_n_layers`` (read by ``initialize_base``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        # Encoder settings must exist before the smp constructor consumes them.
        self.initialize_base(kwargs)
        smp_config = dict(
            encoder_name=self.encoder_name,
            encoder_weights=self.encoder_weights,
            in_channels=in_channels,
            classes=out_channels,
        )
        super().__init__(**smp_config)
        # self.encoder is created by the smp base class, so freeze afterwards.
        self.freeze_encoder_layers()
class UNetPlusPlus(smp.UnetPlusPlus, SegmentationModelBase):
    """``smp.UnetPlusPlus`` whose encoder is partially frozen after construction.

    Args:
        in_channels: number of input channels, forwarded to ``smp.UnetPlusPlus``.
        out_channels: number of output classes, forwarded as ``classes``.
        **kwargs: must provide ``encoder_name`` and
            ``unfreeze_enc_last_n_layers`` (read by ``initialize_base``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        # Encoder settings must exist before the smp constructor consumes them.
        self.initialize_base(kwargs)
        smp_config = dict(
            encoder_name=self.encoder_name,
            encoder_weights=self.encoder_weights,
            in_channels=in_channels,
            classes=out_channels,
        )
        super().__init__(**smp_config)
        # self.encoder is created by the smp base class, so freeze afterwards.
        self.freeze_encoder_layers()
class DeepLabV3Plus(smp.DeepLabV3Plus, SegmentationModelBase):
    """``smp.DeepLabV3Plus`` whose encoder is partially frozen after construction.

    Args:
        in_channels: number of input channels, forwarded to ``smp.DeepLabV3Plus``.
        out_channels: number of output classes, forwarded as ``classes``.
        **kwargs: must provide ``encoder_name`` and
            ``unfreeze_enc_last_n_layers`` (read by ``initialize_base``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        # Encoder settings must exist before the smp constructor consumes them.
        self.initialize_base(kwargs)
        smp_config = dict(
            encoder_name=self.encoder_name,
            encoder_weights=self.encoder_weights,
            in_channels=in_channels,
            classes=out_channels,
        )
        super().__init__(**smp_config)
        # self.encoder is created by the smp base class, so freeze afterwards.
        self.freeze_encoder_layers()