sweep_singlegpu.py
# %%
from dotenv import load_dotenv
load_dotenv()
import torch
from models.vgg import VGG
#from models.resnet import resnet_models
from models.resnet_pretrained import resnet_models
from models.vgg_pretrained import VGG16, VGG13, VGG11
import wandb
import yaml
from utils.utils import set_seeds
import os

# give the wandb service up to 300 seconds to start before timing out
os.environ["WANDB__SERVICE_WAIT"] = "300"
# fix random seeds for reproducibility
set_seeds()
from trainers.trainer import Trainer
from utils.data_utils import CustomDataHandler, get_basic_transform
from configurations.configs import (
    OptimizerType,
    SchedulerType,
    DataHandlerConfig,
    OptimizerConfig,
    SchedulerConfig,
)
from utils.main_utils import get_optimizer, get_scheduler
# load the sweep definition (search space and fixed values) from YAML
with open("configurations/sweep_config.yml") as file:
    sweep_config = yaml.load(file, Loader=yaml.FullLoader)
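
# A minimal sketch of what configurations/sweep_config.yml is assumed to contain,
# inferred from the keys read in this script (the repo's actual file is the
# source of truth); the concrete values below are placeholders:
#
#   method: grid
#   parameters:
#     dataset_name: {value: cifar10}
#     n_runs: {value: 10}
#     model_name: {value: resnet18}
#     n_class: {value: 10}
#     batch_size: {values: [128]}
#     lr: {values: [0.1, 0.01]}
#     ...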
# %%
def run_sweep(config: dict = None):
    # tell wandb to get started: each call opens a new run inside the sweep
    global sweep_id
    run = wandb.init(config=config)
    config = wandb.config

    dataset_name = sweep_config["parameters"]["dataset_name"]["value"]
    train_t, test_t = get_basic_transform(dataset_name)
    dh_config = DataHandlerConfig(
        batch_size=config.batch_size,
        multi_gpu=config.multi_gpu,
        train_slice=config.train_slice,
        test_slice=config.test_slice,
        train_transform=train_t,
        test_transform=test_t,
        dataset=config.dataset_name,
    )
    optimizer_config = OptimizerConfig(
        optimizer_type=OptimizerType.SGD,
        lr=config.lr,
        wd=config.wd,
        momentum=config.momentum,
    )
    scheduler_config = SchedulerConfig(
        scheduler_type=SchedulerType.CosineAnnealingLR,
        max_epochs=config.max_epochs,
    )

    # build data loaders, model, loss, optimizer and scheduler for this run
    custom_dataclass = CustomDataHandler(config=dh_config)
    loaders = custom_dataclass.loaders
    # model = VGG16(model_name=config.model_name, n_class=config.n_class)
    model = resnet_models(model_name=config.model_name, n_class=config.n_class)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = get_optimizer(
        config=optimizer_config,
        model=model,
    )
    scheduler = get_scheduler(
        config=scheduler_config,
        optimizer=optimizer,
        max_epochs=config.max_epochs,
    )

    # train on a single GPU with early stopping; metrics are logged to wandb
    trainer = Trainer(
        model=model,
        loaders=loaders,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        gpu_id=config.gpu_id,
    )
    trainer.pipeline(
        max_epochs=config.max_epochs,
        patience=config.patience,
        wandb_flag=True,
        sweep_id=sweep_id,
        early_stop_verbose=config.early_stop_verbose,
    )

# register the sweep with W&B and launch n_runs trials sequentially on this machine
n_runs = sweep_config["parameters"]["n_runs"]["value"]
sweep_id = wandb.sweep(sweep_config, project='frft-demo')
wandb.agent(sweep_id, function=run_sweep, count=n_runs)
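
# Usage sketch (assuming WANDB_API_KEY is provided via the .env file loaded above):
#   python sweep_singlegpu.py
# wandb.agent calls run_sweep n_runs times, each time drawing one hyperparameter
# combination from the sweep and training a single model on the GPU given by gpu_id.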