-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmain.py
93 lines (71 loc) · 3.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from configs import ConfigLoader
from datetime import datetime
from src import DatasetBuilder, TransformBuilder, ModelBuilder, LossBuilder, LossWrapper, NetIO, Trainer
import argparse
import numpy as np
import os
from torch.utils.data import DataLoader
def set_seed(seed):
    """Seed every global random number generator the training run may use.

    Seeds Python's built-in ``random`` module, NumPy's global RNG and
    PyTorch's CPU and CUDA generators, so that code drawing from any of them
    is repeatable across runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random  # local import: not part of the module-level import block

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU (the current one included); silently ignored
    # when CUDA is unavailable, so no guard is needed.
    torch.cuda.manual_seed_all(seed)
def prepare_environment(args):
    """Load the experiment config, apply command-line overrides and seed the run.

    Args:
        args: Parsed command-line namespace; ``config_path``, ``save_dir`` and
            ``model`` are read from it.

    Returns:
        The loaded config object. ``output["save_dir"]`` is replaced by
        ``args.save_dir`` when given, and is always prefixed with today's date
        as ``YYYYMMDD_``. ``model['name']`` is set to ``args.model``.
    """
    # Strip any CR/LF characters embedded in the path before loading.
    config_path = args.config_path.replace('\n', '').replace('\r', '')
    config = ConfigLoader.load(config_path)
    date = datetime.now().strftime("%Y%m%d")

    if args.save_dir is not None:
        config.output["save_dir"] = args.save_dir
    config.output["save_dir"] = "{}_{}".format(date, config.output["save_dir"])
    config.model['name'] = args.model

    set_seed(config.environment['seed'])

    if config.environment.cuda.flag:
        # Reproducibility: in benchmark mode cuDNN times candidate algorithms
        # and keeps the fastest. That choice can differ between runs, which
        # defeats `deterministic = True`, so benchmarking must be off.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    return config
def build_dataloader(config):
    """Create the training and validation data loaders described by ``config``.

    Args:
        config: Experiment config; ``train['batch_size']``,
            ``dataset['transform_name']`` and ``dataset['name']`` are read.

    Returns:
        ``((train_loader, trainset_config), (val_loader, valset_config))``,
        where each ``*_config`` is the second value returned by
        ``DatasetBuilder.load`` for that split.
    """
    batch_size = config.train['batch_size']
    transform_name = config.dataset['transform_name']
    dataset_name = config.dataset['name']

    train_tf, val_tf = TransformBuilder.load(transform_name)
    train_data, train_meta = DatasetBuilder.load(dataset_name=dataset_name, transform=train_tf, train=True)
    val_data, val_meta = DatasetBuilder.load(dataset_name=dataset_name, transform=val_tf, train=False)

    # Both loaders share every setting except shuffling.
    common = {'batch_size': batch_size, 'num_workers': 4, 'pin_memory': True}
    train_loader = DataLoader(train_data, shuffle=True, **common)
    val_loader = DataLoader(val_data, shuffle=False, **common)
    return (train_loader, train_meta), (val_loader, val_meta)
def build_trainer(config):
    """Assemble the model, loss and checkpoint I/O helper into a ``Trainer``.

    If ``config.model['resume']`` is truthy, the model is passed through
    ``NetIO.load_file`` with ``config.model['ckpt']``. This happens before the
    model is moved to the GPU, which only occurs when
    ``config.environment.cuda.flag`` is set.

    Args:
        config: Experiment config object.

    Returns:
        The constructed ``Trainer``.
    """
    netio = NetIO(config)
    model_cfg = config.model

    net = ModelBuilder.load(model_cfg['name'], num_classes=model_cfg['num_classes'])
    if model_cfg['resume']:
        net = netio.load_file(net, model_cfg['ckpt'])

    # LossWrapper is given a list of losses and a matching list of weights;
    # here there is a single cross-entropy term.
    criterion = LossWrapper(
        [LossBuilder.load("CrossEntropyLoss")],
        [config.train.criterion['loss_weights']],
    )

    if config.environment.cuda.flag:
        net, criterion = net.cuda(), criterion.cuda()

    return Trainer(config=config, model=net, wrapper=criterion, ioer=netio)
def main():
    """Command-line entry point: parse arguments, build components, train and validate.

    Runs ``trainer.train`` followed by ``trainer.validate`` for every epoch
    from ``config.train['start_epoch']`` through ``config.train['epochs']``
    inclusive. It then logs the best score reported by
    ``trainer.ioer.get_best_score()`` when the trainer has a logger.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default="./configs/20220223_cifar100.yml")
    parser.add_argument('--model', type=str, default='resnet34')
    parser.add_argument('--save_dir', type=str)
    args = parser.parse_args()

    config = prepare_environment(args)
    start_epoch = config.train['start_epoch']
    # `epochs` is inclusive: range() below stops just before max_epoch.
    max_epoch = config.train['epochs'] + 1

    trainer = build_trainer(config)
    (train_loader, trainset_config), (val_loader, valset_config) = build_dataloader(config)

    if trainer.logger is not None:
        for item in (trainset_config, valset_config, config.model, config.train, config.output):
            trainer.logger.info(item)

    for epoch in range(start_epoch, max_epoch):
        trainer.train(epoch, train_loader)
        trainer.validate(epoch, val_loader)

    # Same guard as the start-up logging: a trainer without a logger must not
    # crash with AttributeError after training has already finished.
    if trainer.logger is not None:
        trainer.logger.info("best metric: {}".format(trainer.ioer.get_best_score()))


if __name__ == '__main__':
    main()