test-classifier.py
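"""Evaluate a trained SUEP classifier on the test split.

Loads the weights from models/<name>.pth, runs the network over the test
dataset described in the config file, and saves one row per event of
[target, per-class outputs] to models/<name>-results.npy.

Usage:
    python test-classifier.py <name> [-c CONFIG] [-v]

The config file (default: config.yml) is expected to provide the
'architecture', 'dataset', and 'evaluation_pref' entries used below.
"""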
import argparse
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.cuda as tcuda
import yaml
from tqdm import trange

from utils import IsValidFile, get_data_loader
# Architectures are resolved by name via eval() in execute(), so these
# imports must stay even though they are not referenced directly.
from suepvision.smodels import (
    LeNet5,
    get_resnet18,
    get_resnet50,
    get_enet,
    get_convnext,
)
# Fix all random seeds for a reproducible evaluation
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


def execute(name, architecture, dataset, evaluation_pref, verbose):
    # Build the data loader for the test split defined in the config
    test_loader = get_data_loader(
        dataset['test'][0],
        evaluation_pref['batch_size'],
        evaluation_pref['workers'],
        dataset['in_dim'],
        0,
        boosted=dataset['boosted'],
        shuffle=False
    )

    # Instantiate the requested architecture and load the trained weights
    model = eval(architecture)()
    model.load_state_dict(torch.load("models/{}.pth".format(name)))
    cudnn.benchmark = True

    # Accumulates one row per event: [target, per-class outputs]
    test_results = torch.tensor([])

    if verbose:
        tr = trange(len(test_loader), file=sys.stdout)
        tr.set_description('Testing')

    model.eval()
    with torch.no_grad():
        for images, targets in test_loader:
            targets = tcuda.LongTensor(targets, device=0)
            outputs = model(images)
            # Cast targets to float so they can be concatenated with the
            # model outputs in a single results tensor
            batch_results = torch.cat(
                (targets.reshape(-1, 1).float(), outputs), 1
            )
            test_results = torch.cat((test_results, batch_results), 0)
            if verbose:
                tr.update(1)
    if verbose:
        tr.close()

    # Save the per-event results for downstream analysis
    np.save(
        "models/{}-results.npy".format(name),
        test_results.detach().cpu().numpy()
    )


if __name__ == '__main__':
    # Default to CUDA tensors so the model and result buffers live on the GPU
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    parser = argparse.ArgumentParser('Test SUEP Classifier')
    parser.add_argument('name', type=str, help='Model name')
    parser.add_argument('-c', '--config',
                        action=IsValidFile,
                        type=str,
                        help='Path to config file',
                        default='config.yml')
    parser.add_argument('-v', '--verbose',
                        action='store_true',
                        help='Output verbosity')
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    execute(
        args.name,
        config['architecture'],
        config['dataset'],
        config['evaluation_pref'],
        args.verbose
    )