-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmetric_test.py
executable file
·48 lines (43 loc) · 1.57 KB
/
metric_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import sys
import os
import torch
import time
from DataSets import create_datasets, create_dataloader
from Utils.eval import eval_metric_model
from Utils.tools import analysis_dataset
import argparse
import matplotlib.pyplot as plt
# Directory containing this script; used to locate the default dataset config.
cur_path = os.path.abspath(os.path.dirname(__file__))

_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off"})


def _str2bool(value):
    """Convert a command-line string such as "True", "false" or "1" to a bool.

    Without an explicit converter argparse passes the raw string through, and
    any non-empty string -- including "False" -- is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _parse_size(text):
    """Parse a comma-separated size string such as "224,224" into a list of positive ints.

    Raises:
        argparse.ArgumentTypeError: if a component is not an integer or is not positive.
    """
    try:
        size = [int(part) for part in text.split(",")]
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected comma-separated integers such as '224,224', got {text!r}"
        ) from None
    if any(dim <= 0 for dim in size):
        raise argparse.ArgumentTypeError(f"image size must be positive, got {text!r}")
    return size


def parse_args(argv=None):
    """Build the command-line parser and parse *argv* (defaults to ``sys.argv[1:]``).

    Returns:
        argparse.Namespace in which ``size`` is already a list of ints and
        ``mirror`` is a bool.
    """
    parser = argparse.ArgumentParser(description="测试-度量学习")
    # Options with defaults
    # argparse also runs `type` on string defaults, so the default "224,224"
    # becomes [224, 224].
    parser.add_argument("--size", type=_parse_size, help="图像宽高", default="224,224")
    parser.add_argument("--batch", type=int, help="推理batch", default=8)
    # Run-specific options
    parser.add_argument(
        "--txt",
        help="数据集路径",
        default=os.path.join(cur_path, "Config", "dataset.txt"),
    )
    parser.add_argument("--process", help="图像预处理", default="ImageNet")
    # Accepts "--mirror", "--mirror True" and "--mirror False".
    parser.add_argument(
        "--mirror",
        help="融合镜像特征",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
    )
    parser.add_argument("--weights", help="模型权重", required=True)
    return parser.parse_args(argv)


def _load_model(weights):
    """Load a whole pickled model object (not a state_dict) onto the CPU."""
    # SECURITY: this unpickles arbitrary Python objects, which can execute
    # code -- only load checkpoints from trusted sources.
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses to load a
        # full nn.Module; opt out explicitly.
        return torch.load(weights, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases (< 1.13) do not accept the weights_only keyword.
        return torch.load(weights, map_location="cpu")


def _model_task(model):
    """Return ``model.info["task"]``, or None when the model carries no such entry."""
    # NOTE(review): `info` is presumably a dict attached by the training
    # script -- confirm against the code that saves these checkpoints.
    try:
        return model.info["task"]
    except (AttributeError, KeyError, TypeError):
        return None


def main(argv=None):
    """Evaluate a metric-learning checkpoint on the configured dataset and print the result."""
    args = parse_args(argv)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the whole model object rather than a model.state_dict.
    model = _load_model(args.weights)
    print(f"model info is {getattr(model, 'info', None)}")
    # Only metric-learning models are supported. This is an explicit check
    # rather than `assert`, which is stripped under `python -O`. It runs
    # before the model is moved to the target device.
    if _model_task(model) != "metric":
        raise SystemExit("警告: 该模型不是度量学习模型")
    model.to(device)
    model.eval()
    # Dataset
    dataset = analysis_dataset(args.txt)
    # Statistics. "test" is passed positionally as in the original call --
    # presumably the split/mode name; confirm against eval_metric_model.
    result = eval_metric_model(
        model,
        dataset,
        args.size,
        args.process,
        args.batch,
        "test",
        args.mirror,
    )
    print(result)


if __name__ == "__main__":
    main()