-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
52 lines (43 loc) · 1.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -*- coding: utf-8 -*-
import engine
from torch.utils.data import DataLoader
import torch
# GET COMMANDLINE ARGS
args = engine.get_args()

# LOADING MODEL
# engine.load_model receives the model name and args.noc (presumably the
# number of classes -- confirm against engine).
model = engine.load_model(args.model, args.noc)

# DATAPATHS
# Every split is expected under <dataset_path>/<split>/images and
# <dataset_path>/<split>/labels.
train_images = '/'.join((args.dataset_path, 'train', 'images'))
train_labels = '/'.join((args.dataset_path, 'train', 'labels'))
val_images = '/'.join((args.dataset_path, 'validation', 'images'))
val_labels = '/'.join((args.dataset_path, 'validation', 'labels'))
test_images = '/'.join((args.dataset_path, 'test', 'images'))
test_labels = '/'.join((args.dataset_path, 'test', 'labels'))
# DATA LOADERS
def _build_loader(image_dir, label_dir, shuffle):
    """Return a DataLoader over engine.getDataset(image_dir, label_dir).

    All splits share the same resize target (360, 480), batch size and
    worker count taken from the command-line args; only ``shuffle`` varies.
    """
    dataset = engine.getDataset(image_dir, label_dir, size=(360, 480))
    return DataLoader(dataset,
                      batch_size=args.batch_size,
                      num_workers=args.num_of_workers,
                      shuffle=shuffle)


# Only the training split is shuffled; validation/test keep a fixed order.
train_loader = _build_loader(train_images, train_labels, shuffle=True)
val_loader = _build_loader(val_images, val_labels, shuffle=False)
test_loader = _build_loader(test_images, test_labels, shuffle=False)
# TRAINING
# Either train from scratch or resume from the best checkpoint previously
# written to <save_path>/best_model.pth. Both branches leave a CUDA model
# in `trained_model` for the test phase.
if args.fresh_train:
    trainer = engine.Trainer(model.cuda(), train_loader, val_loader,
                             args.save_path, args.max_epochs, args.noc)
    # NOTE(review): assumes Trainer.train() returns the trained model --
    # confirm in engine.
    trained_model = trainer.train()
else:
    # load_state_dict() mutates `model` in place and returns a NamedTuple of
    # missing/unexpected keys, NOT the module -- so chaining `.cuda()` onto
    # it (as before) raised AttributeError. Load first, then move the module.
    # map_location='cpu' lets a checkpoint saved on any GPU index be loaded;
    # the weights are copied into the model's own parameters anyway.
    state_dict = torch.load(args.save_path + '/best_model.pth',
                            map_location='cpu')
    model.load_state_dict(state_dict)
    trained_model = model.cuda()
# TESTING
# Evaluate on the held-out test split; results go to <save_path>/eval_test.
test_eval_dir = args.save_path + '/eval_test'
tester_test = engine.Tester(trained_model, test_loader, test_eval_dir, args.noc)
tester_test.test()