-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrobot_train_hyperparameter_tuning.py
50 lines (37 loc) · 1.99 KB
/
robot_train_hyperparameter_tuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import yaml
import argparse
from robot_move import MoveRobot
if __name__ == "__main__":
    # Hyperparameter tuning for the TWR agent.
    #
    # Loads a YAML config, then for every value in
    # config['hyperparameter_tuning']['value_list'] overwrites
    # config[<type>][<tuning_variable>] with that value and runs a MoveRobot
    # built from the updated config.

    # The parser is used directly (not as a parent parser), so the default
    # -h/--help option is kept; otherwise the help= texts below are unreachable.
    parser = argparse.ArgumentParser('robot Train')
    # model config
    parser.add_argument("--config-file",
                        default="./configs/main.yaml",
                        metavar="FILE", help="path to config file")
    parser.add_argument("--gpu", type=str, default="0", help="GPU id")
    args = parser.parse_args()
    model_config_path = args.config_file

    # reads model config (explicit encoding so parsing does not depend on the
    # platform's locale default)
    with open(model_config_path, 'r', encoding='utf-8') as f:
        model_config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document; fail with a clear
    # message instead of a TypeError on the item assignments below.
    if not isinstance(model_config, dict):
        raise ValueError(
            f"Config file {model_config_path!r} must contain a YAML mapping, "
            f"got {type(model_config).__name__}"
        )

    # record run-time settings alongside the values read from the file
    model_config['config_path'] = model_config_path
    model_config['gpu_id'] = args.gpu
    model_config['mode'] = 'train'

    print(f'\n ===== Config File Path given: {model_config_path} =====\n')
    print(yaml.dump(model_config, sort_keys=False))

    # get hyperparameter tuning variables from config file
    tuning_config = model_config.get('hyperparameter_tuning')
    if not isinstance(tuning_config, dict):
        raise KeyError(
            f"Config file {model_config_path!r} has no "
            "'hyperparameter_tuning' mapping"
        )
    missing_keys = [key for key in ('type', 'tuning_variable', 'value_list')
                    if key not in tuning_config]
    if missing_keys:
        raise KeyError(
            f"'hyperparameter_tuning' is missing required keys: {missing_keys}"
        )

    hyperparam_type = tuning_config['type']
    hyperparameter_tuning_variable = tuning_config['tuning_variable']
    hyperparameter_value_list = tuning_config['value_list']

    # The target section must already exist, because each run assigns into
    # model_config[hyperparam_type]; check once up front rather than failing
    # inside the loop.
    if hyperparam_type not in model_config:
        raise KeyError(
            f"'hyperparameter_tuning.type' refers to section "
            f"{hyperparam_type!r}, which is not present in the config"
        )

    print(" ===== Hyperparameter Tuning ===== ")
    print(f"hyperparameter_tuning_variable = {hyperparameter_tuning_variable}")
    print(f"hyperparameter_value_list = {hyperparameter_value_list}\n")

    for run_idx, hyperparameter_value in enumerate(hyperparameter_value_list):
        print(f" ===== Run {run_idx} : {hyperparameter_tuning_variable} = {hyperparameter_value} ===== ")

        # overwrites hyperparameter_tuning_variable for this run; the same
        # config dict is reused, so the value persists until the next run
        # replaces it
        model_config[hyperparam_type][hyperparameter_tuning_variable] = hyperparameter_value

        robot_trainer = MoveRobot(
            model_config,
            hyperparameter_tuning_variable=hyperparameter_tuning_variable,
            hyperparameter_value=hyperparameter_value,
        )
        robot_trainer.run()