-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_ssl.py
111 lines (90 loc) · 3.22 KB
/
main_ssl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import wandb
import logging
import sys
from torch.utils.data import DataLoader
from accelerate import Accelerator
from omegaconf import OmegaConf
from arguments import parser
from datasets import create_dataset
from log import setup_default_logging
from metric_learning.factory import create_metric_learning
from metric_learning.models import MetricModel
from query_strategies import torch_seed
from query_strategies.scheds import create_scheduler
from query_strategies.optims import create_optimizer
_logger = logging.getLogger('train')
def run(cfg):
    """Run self-supervised (metric-learning) pre-training from an OmegaConf config.

    Pipeline: build the accelerator, seed, load the train split, create the
    save directory, optionally init wandb, then build the encoder, optimizer,
    scheduler, and SSL trainer and call ``fit``.

    Args:
        cfg: OmegaConf config with DEFAULT, DATASET, MODEL, OPTIMIZER,
            SCHEDULER, SSL, and TRAIN sections (see accessed keys below).

    Raises:
        FileExistsError: if a checkpoint already exists at the target
            save path, to avoid silently overwriting a previous run.
    """
    # set accelerator (handles device placement, AMP, grad accumulation)
    accelerator = Accelerator(
        gradient_accumulation_steps = cfg.TRAIN.grad_accum_steps,
        mixed_precision             = cfg.TRAIN.mixed_precision
    )

    setup_default_logging()
    torch_seed(cfg.DEFAULT.seed)

    # set device
    _logger.info('Device: {}'.format(accelerator.device))

    # load dataset (only the train split is used for SSL pre-training)
    trainset, _, _ = create_dataset(
        datadir  = cfg.DATASET.datadir,
        dataname = cfg.DATASET.name,
        img_size = cfg.DATASET.img_size,
        mean     = cfg.DATASET.mean,
        std      = cfg.DATASET.std,
        aug_info = cfg.DATASET.aug_info,
        **cfg.DATASET.get('params', {})
    )

    trainloader = DataLoader(
        trainset,
        batch_size  = cfg.DATASET.batch_size,
        num_workers = cfg.DATASET.num_workers,
    )

    # make save directory
    savedir  = os.path.join(cfg.DEFAULT.savedir, cfg.DATASET.name, cfg.SSL.method, cfg.MODEL.name)
    savepath = os.path.join(savedir, 'ckp.pt')
    # NOTE: was `assert not os.path.isfile(...)` — asserts are stripped under
    # `python -O`, which would silently overwrite a finished run; raise instead.
    if os.path.isfile(savepath):
        raise FileExistsError(f'{savepath} already exists')
    os.makedirs(savedir, exist_ok=True)

    # save configs alongside the checkpoint for reproducibility
    OmegaConf.save(cfg, os.path.join(savedir, 'configs.yaml'))

    # initialize wandb
    if cfg.TRAIN.wandb.use:
        wandb.init(name=cfg.DEFAULT.exp_name, project=cfg.TRAIN.wandb.project_name, entity=cfg.TRAIN.wandb.entity, config=OmegaConf.to_container(cfg))

    # visual encoder used for metric learning
    vis_encoder = MetricModel(
        modelname  = cfg.MODEL.name,
        pretrained = cfg.MODEL.pretrained,
        **cfg.MODEL.get('params', {})
    )

    # optimizer and LR scheduler
    optimizer = create_optimizer(opt_name=cfg.OPTIMIZER.name, model=vis_encoder, lr=cfg.OPTIMIZER.lr, opt_params=cfg.OPTIMIZER.params)
    scheduler = create_scheduler(
        sched_name    = cfg.SCHEDULER.name,
        optimizer     = optimizer,
        epochs        = cfg.TRAIN.epochs,
        params        = cfg.SCHEDULER.params,
        warmup_params = cfg.SCHEDULER.get('warmup_params', {})
    )

    # SSL trainer (method selected by cfg.SSL.method) writes its best
    # checkpoint to `savepath`
    SSLTrainer = create_metric_learning(
        method_name = cfg.SSL.method,
        savepath    = savepath,
        accelerator = accelerator,
        seed        = cfg.DEFAULT.seed,
        dataname    = cfg.DATASET.name,
        img_size    = cfg.DATASET.img_size,
        ssl_params  = cfg.SSL.get('params', {})
    )

    SSLTrainer.fit(
        epochs      = cfg.TRAIN.epochs,
        vis_encoder = vis_encoder,
        dataloader  = trainloader,
        optimizer   = optimizer,
        scheduler   = scheduler,
        device      = accelerator.device
    )
if __name__ == '__main__':
    # CSI training script entry point.
    # Raise the recursion limit before running — presumably required by
    # downstream training code; TODO confirm which call path needs it.
    sys.setrecursionlimit(10000)

    # parse config and launch training
    config = parser()
    run(config)