-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnet.py
61 lines (56 loc) · 1.94 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from math import sqrt
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from utils import *
import os
from loss import *
from model import *
from skimage.feature.tests.test_orb import img
# Tell the Intel/LLVM OpenMP runtime to tolerate a second copy of itself
# being loaded into this process instead of aborting. NOTE(review): presumably
# needed because torch and another dependency each bundle an OpenMP runtime;
# the runtime's own docs call this an unsupported workaround that can cause
# crashes or wrong results, so confirm it is still required.
os.environ.update(KMP_DUPLICATE_LIB_OK='TRUE')
class Net(nn.Module):
    """Wrap one of several named architectures together with its training loss.

    The architecture and loss classes referenced below are not defined in this
    file; presumably they come from the ``from model import *`` and
    ``from loss import *`` star imports.

    Args:
        model_name: Key selecting the architecture; must be one of the keys of
            ``Net._BUILDERS`` (e.g. ``'DNANet'``, ``'ISNet'``, ``'U-Net'``).
        mode: ``'train'`` is forwarded as ``mode='train'`` to the architectures
            whose constructors take a ``mode`` argument (DNANet, DNANet_BY,
            ISNet, UIUNet); any other value is forwarded as ``mode='test'``.
            Architectures without a ``mode`` argument ignore it.

    Raises:
        ValueError: if ``model_name`` is not a known architecture.
    """

    # Builders keyed by the accepted ``model_name`` strings. Each receives the
    # normalised mode ('train' or 'test'). Lambdas defer looking up the class
    # name until that architecture is actually requested, so a class missing
    # from ``model`` only breaks the names that use it.
    _BUILDERS = {
        'DNANet': lambda mode: DNANet(mode=mode),
        # NOTE(review): spelled ``DNAnet_BY`` (lower-case "n"), unlike
        # ``DNANet`` above; confirm it matches the class exported by ``model``.
        'DNANet_BY': lambda mode: DNAnet_BY(mode=mode),
        'ACM': lambda _: ACM(),
        'ALCNet': lambda _: ALCNet(),
        'ISNet': lambda mode: ISNet(mode=mode),
        'RISTDnet': lambda _: RISTDnet(),
        'UIUNet': lambda mode: UIUNet(mode=mode),
        'U-Net': lambda _: Unet(),
        'ISTDU-Net': lambda _: ISTDU_Net(),
        'RDIAN': lambda _: RDIAN(),
        'ResUNet': lambda _: ResUNet(),
    }

    def __init__(self, model_name: str, mode: str):
        super().__init__()
        # Validate before building anything. An unknown name fails here with a
        # clear message rather than leaving ``self.model`` unset and surfacing
        # later as an AttributeError inside ``forward``.
        try:
            build = self._BUILDERS[model_name]
        except KeyError:
            raise ValueError(
                f"Unknown model_name {model_name!r}; "
                f"expected one of {sorted(self._BUILDERS)}"
            ) from None
        self.model_name = model_name
        # Keep the original construction order: default loss, then the
        # network, then any architecture-specific loss override.
        self.cal_loss = SoftIoULoss()
        self.model = build('train' if mode == 'train' else 'test')
        if model_name == 'ISNet':
            # ISNet is trained with its own loss instead of SoftIoULoss.
            self.cal_loss = ISNetLoss()

    def forward(self, img):
        """Run the wrapped architecture on ``img`` and return its raw output.

        NOTE(review): presumably ``img`` is an image tensor batch; the
        expected shape depends on the selected architecture.
        """
        return self.model(img)

    def loss(self, pred, gt_mask):
        """Return ``self.cal_loss(pred, gt_mask)``.

        ``self.cal_loss`` is ``ISNetLoss`` for ISNet and ``SoftIoULoss``
        otherwise. NOTE(review): ``pred`` is presumably whatever ``forward``
        returned and ``gt_mask`` a ground-truth mask tensor; confirm shapes
        against the loss implementations.
        """
        return self.cal_loss(pred, gt_mask)