-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
103 lines (75 loc) · 3.26 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import logging
import os
from pathlib import Path
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.io import write_png
from dataloader import ImageDataset, MultiImageDataset
from unet import UNet, UNetPlus
# Example invocations (multi-image model, then single-image model), both on
# the validation split:
# python predict.py --exp 220610_162631 --validation --epochs 3
# python predict.py --exp 220607_163752 --single_image --validation --epochs 30
# NOTE(review): IMGSIZE is not referenced anywhere in this file; presumably it
# mirrors the image size used by the dataloader/training code, and whether it
# is (H, W) or (W, H) is not visible here — confirm before relying on it.
IMGSIZE = (66, 45)
# Root directory with one sub-directory per experiment (--exp); the script
# loads checkpoint_root / <exp> / checkpoint_epoch{N}.pth.
checkpoint_root = Path('./checkpoints')
def save_result(pred, save_dir, fname, compression_level=6):
    """Write one predicted map to ``save_dir / (fname + '.png')``.

    Args:
        pred: Tensor passed straight to ``torchvision.io.write_png``, which
            expects a ``uint8`` tensor shaped (C, H, W) with C equal to 1 or 3.
        save_dir: Output directory as a ``pathlib.Path``. It is not created
            here; the caller is responsible for that.
        fname: Output file stem; ``'.png'`` is appended.
        compression_level: zlib compression level (0-9) forwarded to
            ``write_png``. PNG is lossless, so this only trades file size for
            encoding time. The default of 6 matches ``write_png``'s default, so
            existing callers get byte-identical files.
    """
    path = save_dir / (fname + '.png')
    write_png(pred, str(path), compression_level=compression_level)
def predict(model, device, save_dir, single_image=False, is_validation=False):
    """Run ``model`` over the test or validation split and save one PNG per output.

    Output names are derived from ``test_dataset.sem_list``: the last
    ``_``-separated token of each entry is dropped, duplicates are removed, and
    the stems are sorted. The i-th batch from the unshuffled loader is saved
    under the i-th stem, so this relies on the dataset yielding samples in that
    same sorted, grouped order.
    NOTE(review): that ordering contract lives in dataloader.py — confirm.

    In single-image mode ``ImageDataset`` is read in batches of 4, and the
    sigmoid outputs of each batch are averaged into one map. Presumably each
    stem owns exactly 4 consecutive images; verify against the dataset. In
    multi-image mode ``MultiImageDataset`` is read one sample per batch.

    Args:
        model: Network producing logits; it is switched to eval mode here.
        device: ``torch.device`` (or device string) to run inference on.
        save_dir: ``pathlib.Path`` output directory, created if missing.
        single_image: Select ``ImageDataset`` (True) or ``MultiImageDataset``.
        is_validation: Predict the 'Validation' split (built with
            ``is_test=True``) instead of the 'Test' split.
    """
    save_dir.mkdir(parents=True, exist_ok=True)

    dataset_cl = ImageDataset if single_image else MultiImageDataset
    # Single-image mode averages a batch of 4 predictions into one map (see
    # the mean below); multi-image mode yields one output per sample.
    batch_size = 4 if single_image else 1
    if is_validation:
        test_dataset = dataset_cl(option='Validation', is_test=True)
    else:
        test_dataset = dataset_cl(option='Test')

    # 'stem_suffix' -> 'stem'. rpartition yields '' for names without '_',
    # exactly like the former '_'.join(s.split('_')[:-1]).
    fnames = sorted({s.rpartition('_')[0] for s in test_dataset.sem_list})

    # Pinned host memory only speeds up host-to-GPU copies; requesting it
    # without CUDA is useless and makes DataLoader emit a warning.
    use_pinned_memory = torch.device(device).type == 'cuda'
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, pin_memory=use_pinned_memory)

    # Names are paired with batches purely by position, so a count mismatch
    # means outputs get misnamed (or indexing fails part-way). Warn instead of
    # asserting, since the original hard check had been disabled.
    if len(test_loader) != len(fnames):
        logging.warning('Loader yields %d batches but %d output names were '
                        'derived; saved files may be misnamed',
                        len(test_loader), len(fnames))

    model.eval()
    with torch.no_grad():
        for batch, org in enumerate(test_loader):
            org = org.to(device=device)
            pred = torch.sigmoid(model(org))
            # Average over the batch axis: merges the single-image predictions
            # of one stem, or simply drops the size-1 batch axis.
            pred = pred.float().cpu().mean(dim=0)
            # Round to nearest rather than truncate, so the 8-bit quantization
            # is unbiased; the clamp keeps the uint8 cast in range.
            pred = (pred * 255).round().clamp(0, 255).to(torch.uint8)
            save_result(pred, save_dir, fnames[batch])
def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``epochs``, ``bilinear``, ``exp``,
        ``validation`` and ``single_image``.
    """
    parser = argparse.ArgumentParser(description='Depth estimation for SEM')
    # The epoch number only selects which checkpoint_epoch{N}.pth is loaded.
    parser.add_argument('--epochs', '-e', type=int, default=3,
                        help='Epoch of the checkpoint to load (checkpoint_epoch{N}.pth)')
    parser.add_argument('--bilinear', action='store_true', default=False,
                        help='Use bilinear upsampling')
    parser.add_argument('--exp', default='first',
                        help='Experiment name, i.e. the checkpoint sub-directory (e.g. 220610_162631)')
    parser.add_argument('--validation', action='store_true', default=False,
                        help='Predict on the validation split instead of the test split')
    parser.add_argument('--single_image', action='store_true', default=False,
                        help='Single image input')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Parse CLI options, then configure INFO-level logging.
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # UNet pairs with the single-image dataset, UNetPlus with the multi-image one.
    model = (UNet(n_channels=1, n_classes=1, bilinear=args.bilinear)
             if args.single_image
             else UNetPlus(n_channels=1, bilinear=args.bilinear))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Validation and test predictions are written to separate trees, each
    # keyed by the experiment name.
    save_root = Path('./results_validation/') if args.validation else Path('./results')
    save_dir = save_root / args.exp
    model_path = checkpoint_root / args.exp / f'checkpoint_epoch{args.epochs}.pth'

    logging.info(f'Loading model {model_path}')
    logging.info(f'Using device {device}')
    model.to(device=device)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    logging.info('Model loaded')

    predict(model=model, device=device, save_dir=save_dir,
            single_image=args.single_image, is_validation=args.validation)