-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
57 lines (43 loc) · 1.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
'''
Created on Jun 22, 2021
@author: Navid Dianati
'''
import logging
from MOA_L1000 import config, utils, trainer , dataload
# basicConfig() installs a default stderr handler on the root logger (it is a
# no-op if the root logger already has handlers), so records emitted through
# the 'main' logger below are actually printed.
logging.basicConfig()
# Script-level logger; INFO and above are emitted (e.g. the abort message in run_model).
logger = logging.getLogger('main')
logger.setLevel(logging.INFO)
def run_model(params):
    """Train one model configuration once per seed it lists.

    Validates the required keys, exports ``params`` via
    ``utils.export_params``, loads the data once with the configured loader,
    builds a single ``trainer.Trainer`` and then runs ``Trainer.run_cv`` once
    for every seed. A ``KeyboardInterrupt`` raised during a seed's run is
    logged and stops the remaining seeds instead of propagating.

    Args:
        params: configuration dict. Keys read directly here:
            ``'data_loader'`` -- a zero-argument callable returning
            ``(X, y, w, X_holdout, y_holdout)``, or a string expression that
            ``eval`` resolves (in this module's namespace) to such a callable;
            ``'SEEDS'`` -- iterable of seeds, one cross-validation run each.
            The whole dict is also forwarded as keyword arguments to both
            ``trainer.Trainer`` and ``Trainer.run_cv``.

    Raises:
        ValueError: if ``'data_loader'`` or ``'SEEDS'`` is missing or None.
    """
    loader_spec = params.get('data_loader')
    seeds = params.get("SEEDS")
    # Fail fast: previously a missing key only surfaced as an opaque TypeError
    # after the (expensive) data load and trainer construction.
    if loader_spec is None:
        raise ValueError("params['data_loader'] is required")
    if seeds is None:
        raise ValueError("params['SEEDS'] is required")
    seeds = list(seeds)  # materialize so one-shot iterables also work
    if not seeds:
        logger.warning("No SEEDS configured; nothing to train.")
        return

    utils.export_params(params)

    if callable(loader_spec):
        data_loader = loader_spec
    else:
        # NOTE(review): eval() executes arbitrary code taken from the config.
        # Acceptable only while params come from trusted, in-repo sources
        # (presumably the config module, as in main()); never route
        # user-supplied params through here.
        data_loader = eval(loader_spec)

    # Load data and perform transformations (done once, shared by all seeds).
    X, y, w, X_holdout, y_holdout = data_loader()

    # One trainer instance is reused across seeds; only its seed changes.
    tr = trainer.Trainer(**params)
    tr.set_training_data(X, y, w, X_holdout, y_holdout)

    for seed in seeds:
        utils.seed_all(seed)
        tr.seed = seed
        # Run a full cross-validation for this seed.
        try:
            tr.run_cv(**params)
        except KeyboardInterrupt:
            # Ctrl-C skips the remaining seeds but returns normally.
            logger.info('Training aborted.')
            break
def main():
    """Train each queued model configuration, in list order.

    Every queued ``config.get_params_*()`` call is evaluated before the first
    training run starts. To queue another configuration, uncomment its entry.
    """
    queued_configs = [
        # config.get_params_DNN5(),
        # config.get_params_DNN6(),
        # config.get_params_DNN10(),
        # config.get_params_DNN11(),
        # config.get_params_DNN8(),
        # config.get_params_DNN15(),
        # config.get_params_DNN16(),
        config.get_params_DNN17(),
    ]
    for model_params in queued_configs:
        run_model(model_params)


# Launch training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()