-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
114 lines (92 loc) · 3.4 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import time
import json
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from network_files import MaskRCNN
from backbone import resnet50_fpn_backbone
from draw_box_utils import draw_objs
def create_model(num_classes, box_thresh=0.5):
    """Build a ``MaskRCNN`` model on top of a ResNet-50 FPN backbone.

    Args:
        num_classes: Number of classes forwarded to ``MaskRCNN``.
        box_thresh: Score threshold forwarded as both ``rpn_score_thresh``
            and ``box_score_thresh``.

    Returns:
        The newly constructed ``MaskRCNN`` instance. No weights are loaded
        here; that is left to the caller.
    """
    # The same threshold is deliberately used for the RPN and the box head.
    detector_kwargs = {
        "num_classes": num_classes,
        "rpn_score_thresh": box_thresh,
        "box_score_thresh": box_thresh,
    }
    return MaskRCNN(resnet50_fpn_backbone(), **detector_kwargs)
def time_synchronized():
    """Return the current wall-clock time, synchronizing CUDA first.

    When CUDA is available, ``torch.cuda.synchronize()`` is called before the
    clock is read so that the timestamp is taken after the synchronize call
    returns. Without CUDA the clock is read directly.

    Returns:
        The value of ``time.time()`` (seconds since the epoch, as a float).
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def main():
    """Run Mask R-CNN inference on a single image, then show and save the result.

    The function builds the model with ``create_model``, loads the checkpoint
    at ``weights_path``, reads the class-index mapping from
    ``label_json_path``, runs the model on ``img_path`` and draws the
    predictions with ``draw_objs``. The annotated image is displayed with
    matplotlib and saved next to the input as ``<name>_result<ext>``.
    If no boxes are predicted, a message is printed and the function returns
    without drawing or saving anything.

    Raises:
        FileNotFoundError: If the weights file, the label JSON file or the
            input image does not exist.
    """
    num_classes = 20  # background is not included
    box_thresh = 0.5
    weights_path = "./save_weights/model_14.pth"
    img_path = "./images/000074.jpg"
    label_json_path = "./pascal_voc_indices.json"

    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model; the +1 presumably accounts for the background class
    model = create_model(num_classes=num_classes + 1, box_thresh=box_thresh)

    # load train weights
    # Explicit raises are used instead of ``assert`` so the checks still run
    # when Python is started with -O.
    if not os.path.exists(weights_path):
        raise FileNotFoundError("{} file does not exist.".format(weights_path))
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights_dict = torch.load(weights_path, map_location="cpu")
    # Checkpoints may wrap the state dict under a "model" key.
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict)
    model.to(device)

    # read class_indict
    if not os.path.exists(label_json_path):
        raise FileNotFoundError(
            "json file {} does not exist.".format(label_json_path)
        )
    # Explicit encoding so decoding does not depend on the platform locale.
    with open(label_json_path, "r", encoding="utf-8") as json_file:
        category_index = json.load(json_file)

    # load image
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"{img_path} does not exist.")
    original_img = Image.open(img_path).convert("RGB")

    # from pil image to tensor, do not normalize image
    data_transform = transforms.Compose([transforms.ToTensor()])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    model.eval()  # switch to evaluation mode
    with torch.no_grad():
        # init: one forward pass on an all-zero image of the same size,
        # presumably a warm-up so one-time setup cost is not timed below
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        predictions = model(img.to(device))[0]
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

        predict_boxes = predictions["boxes"].to("cpu").numpy()
        predict_classes = predictions["labels"].to("cpu").numpy()
        predict_scores = predictions["scores"].to("cpu").numpy()
        predict_mask = predictions["masks"].to("cpu").numpy()
        # drop the singleton axis 1: [num, 1, h, w] -> [num, h, w]
        predict_mask = np.squeeze(predict_mask, axis=1)

        if len(predict_boxes) == 0:
            # Runtime message (Chinese): "No objects were detected!"
            print("没有检测到任何目标!")
            return

        plot_img = draw_objs(
            original_img,
            boxes=predict_boxes,
            classes=predict_classes,
            scores=predict_scores,
            masks=predict_mask,
            category_index=category_index,
            line_thickness=3,
            font="arial.ttf",
            font_size=20,
        )
        plt.imshow(plot_img)
        plt.show()
        # save the annotated prediction next to the input image
        result_path, ext = os.path.splitext(img_path)
        plot_img.save(result_path + "_result" + ext)
# Run the prediction demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()