-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_evaluate.py
58 lines (50 loc) · 2.15 KB
/
main_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from torch.utils.data import DataLoader
import os
# =====================================================
from models.model import Model
from config.config import config
from train.Ex import Evaluator
from utils.utils import load_model, print_config, makedir, read_json
from train.Dataset import CustomDataset
from train.Tokenizer import Tokenizer
# =====================================================
# Optional NVIDIA apex dependency (amp / fused optimizers). The import is
# best-effort so the script still runs when apex is not installed; apex is not
# used by the evaluation code below.
try:
    import apex
    from apex import amp, optimizers
except ImportError:
    # Bind explicit sentinels rather than leaving the names undefined, so any
    # code that wants apex can test ``amp is None`` instead of hitting NameError.
    apex = amp = optimizers = None
# Dataset-metadata key -> attribute on ``opt`` that receives its value.
_DATASET_SIZE_KEYS = (
    ("num_label", "input_size"),
    ("num_category", "category_size"),
    ("num_category_type", "category_type_size"),
)


def _build_tokenizer(opt):
    """Build the tokenizer from the paths configured on *opt*.

    Calls ``Tokenizer.build`` with the dataset/tokenizer/category-map paths,
    then ``auto_binary_encode()``, and returns the tokenizer.
    """
    tokenizer = Tokenizer()
    tokenizer.build(path=opt.d_code_dataset_path,
                    tokenizer_path=opt.tokenizer_path,
                    category_map_path=opt.category_map_path,
                    category_type_map_path=opt.category_type_map_path)
    tokenizer.auto_binary_encode()
    return tokenizer


def _apply_dataset_sizes(opt, metadata):
    """Copy the label/category size entries from *metadata* onto *opt*.

    Sets ``opt.input_size``, ``opt.category_size`` and
    ``opt.category_type_size`` from the ``num_label``, ``num_category`` and
    ``num_category_type`` entries respectively.

    Raises:
        KeyError: if any of those entries is missing (or ``None``). Failing
            here gives a clear message instead of storing ``None`` on *opt*
            and letting model construction fail later.
    """
    missing = [key for key, _ in _DATASET_SIZE_KEYS if metadata.get(key) is None]
    if missing:
        raise KeyError(f"dataset metadata is missing required size entries: {missing}")
    for key, attr in _DATASET_SIZE_KEYS:
        setattr(opt, attr, metadata.get(key))


def main():
    """Load the saved model checkpoint and evaluate it on the 'evaluate' split.

    Report and prediction output directories, the metric name, batch size,
    top-k and device are all taken from ``config()`` and forwarded to
    ``Evaluator.evaluate``.
    """
    opt = config()
    tokenizer = _build_tokenizer(opt)

    # Model dimensions come from the dataset metadata, not from the config.
    metadata = read_json(parient_dir=opt.metadata_dataset_directory, name='dataset')
    _apply_dataset_sizes(opt, metadata)

    model = Model(opt.model, opt)
    print_config(opt, model)

    # Checkpoint location: <model_directory>/<batch_size>/<ModelClassName>.
    model_name = type(model).__name__
    batch_dir = os.path.join(opt.model_directory, str(opt.batch_size))
    # NOTE(review): creating the directory cannot make a missing checkpoint
    # appear; presumably kept for parity with the training script — confirm.
    makedir(batch_dir, model_name)
    save_model_path = os.path.join(batch_dir, model_name)

    load_model(save_model_path, model)
    # Inference mode (assumes Model is a torch.nn.Module — TODO confirm).
    model.eval()

    test_data = CustomDataset(parient_dir=opt.dataset_directory,
                              keys=opt.keys,
                              max_len=opt.max_appearances,
                              split='evaluate',
                              tokenizer=tokenizer,
                              top_k_evaluate=opt.top_k,
                              opt=opt)
    evaluator = Evaluator(model, tokenizer=tokenizer)
    evaluator.evaluate(test_data,
                       batch_size=opt.batch_size,
                       top_k=opt.top_k,
                       max_len=opt.max_appearances,
                       report_directory=opt.report_directory,
                       predicted_directory=opt.predicted_directory,
                       metric_name=opt.metric_name,
                       shuffle=False,  # keep evaluation order deterministic
                       device=opt.device)


if __name__ == '__main__':
    main()