forked from thematrixduo/MXGNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_RAVEN.py
105 lines (84 loc) · 4.02 KB
/
test_RAVEN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import argparse
import time
import numpy as np
import math
import torch.optim as optim
from torchvision import datasets,transforms
import torch.utils
from model_RAVEN import Model
from torch.utils.data import DataLoader
import torch.nn.utils
import matplotlib.pyplot as plt
import itertools
from data_utility import ToTensor
from data_utility import dataset_raven as dataset
def _evaluate(model, loader, device, image_size):
    """Return the fraction of samples in *loader* whose argmax score equals the label.

    Each batch is unpacked as ``(data, label, meta_target)`` and reshaped the
    same way the original script did: ``data`` to ``(-1, 16, image_size,
    image_size)``, ``label`` to a flat vector, ``meta_target`` to ``(-1, 9)``.
    The denominator is ``len(loader.dataset)``; an empty dataset yields NaN
    instead of raising ``ZeroDivisionError``.
    """
    correct = 0
    for data, label, meta_target in loader:
        # 16 planes per sample and 9 meta-target entries come from the original
        # reshapes (presumably RAVEN context+candidate panels / rule
        # attributes -- confirm against data_utility.dataset_raven).
        data = data.view(-1, 16, image_size, image_size).to(device)
        label = label.view(-1).to(device)
        meta_target = meta_target.view(-1, 9).to(device)
        _, score_vec = model(data, label, meta_target)
        _, pred = torch.max(score_vec, 1)
        correct += (pred == label).sum().item()
    total = len(loader.dataset)
    return correct / total if total else float('nan')


def test(model,testloader,validloader,device,args):
    """Evaluate *model* on the test split and, optionally, the validation split.

    Args:
        model: callable mapping ``(data, label, meta_target)`` to
            ``(_, score_vec)``; its ``eval()`` method is called first.
        testloader: iterable of ``(data, label, meta_target)`` batches that
            exposes a ``dataset`` attribute supporting ``len``.
        validloader: same shape as *testloader*; only used when
            ``args.valid_result`` is true.
        device: ``torch.device`` each batch is moved to.
        args: namespace providing ``valid_result`` and ``image_size``.

    Returns:
        ``(valid_accuracy, test_accuracy)``. ``valid_accuracy`` is ``None``
        when ``args.valid_result`` is false; accuracies are NaN for an empty
        dataset. Both values are also printed.
    """
    # view() requires integer sizes; argparse may hand us a float.
    image_size = int(args.image_size)
    model.eval()
    valid_accuracy = None
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        if args.valid_result:
            valid_accuracy = _evaluate(model, validloader, device, image_size)
            print('Validation Accuracy:', valid_accuracy)
        test_accuracy = _evaluate(model, testloader, device, image_size)
        print('Test Accuracy:', test_accuracy)
    return valid_accuracy, test_accuracy


# This file matches pytest's ``test_*.py`` pattern and ``test`` matches its
# function pattern; without this pytest would collect it and fail on the
# missing "fixtures" ``model``, ``testloader``, ...
test.__test__ = False
def main():
    """Evaluate a trained RAVEN ``Model`` checkpoint from the command line.

    Parses arguments and builds test and validation ``DataLoader`` objects
    over ``args.data``. It then loads the state dict at ``--model-path-name``
    and delegates accuracy computation and printing to ``test``.
    """
    parser = argparse.ArgumentParser(description='RAVEN test args')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for testing (default: 128)')
    # Must be an int: it is used as a tensor dimension in view(). The old
    # type=float made an explicit "--image-size 80" crash there.
    parser.add_argument('--image-size', type=int, default=80, metavar='IMSIZE',
                        help='input image size (default: 80)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA')
    parser.add_argument('--multi-gpu', action='store_true', default=False,
                        help='parallel evaluation on multiple GPUs')
    parser.add_argument('--valid_result', action='store_true', default=False,
                        help='compute results on validation dataset')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--workers', type=int, default=8, metavar='W',
                        help='number of DataLoader worker processes (default: 8)')
    parser.add_argument('--model-path-name', default='', type=str, metavar='PATH',
                        help='The path+name of model to be loaded')
    args = parser.parse_args()
    if not args.model_path_name:
        # torch.load('') would otherwise fail with an opaque FileNotFoundError.
        parser.error('--model-path-name is required')

    # --seed was previously parsed but never applied.
    torch.manual_seed(args.seed)
    # Replacement for the deprecated torch.set_default_tensor_type('torch.FloatTensor').
    torch.set_default_dtype(torch.float32)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if not args.no_cuda and not use_cuda:
        print('CUDA is not available; falling back to CPU.')
    device = torch.device('cuda' if use_cuda else 'cpu')

    test_data = dataset(args.data, "test", args.image_size,
                        transform=transforms.Compose([ToTensor()]))
    valid_data = dataset(args.data, "val", args.image_size,
                         transform=transforms.Compose([ToTensor()]))
    testloader = DataLoader(test_data, batch_size=args.batch_size,
                            num_workers=args.workers)
    validloader = DataLoader(valid_data, batch_size=args.batch_size,
                             num_workers=args.workers)

    model = Model(args.image_size, args.image_size).to(device)
    # DataParallel needs the module's parameters on a GPU, so only wrap when
    # CUDA is actually in use (previously --no-cuda --multi-gpu would break).
    if use_cuda and args.multi_gpu and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only run.
    # NOTE(review): state-dict keys must match the (possibly DataParallel-
    # wrapped) model above -- confirm against how the checkpoint was saved.
    model.load_state_dict(torch.load(args.model_path_name, map_location=device))
    test(model, testloader, validloader, device, args)
if __name__ == '__main__':
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()