From c2abd9ac6a5a985c67ab74574e8d65e68b7cc606 Mon Sep 17 00:00:00 2001 From: 1pha <1phantasmas@korea.ac.kr> Date: Wed, 13 Mar 2024 07:40:12 +0000 Subject: [PATCH] Add wandb prefix. Remove pre-trained checkpoint --- config/sweep/ppmi_sweep.yaml | 12 +++++++++++- sweep.py | 26 +++++++++++++++----------- sweep_command.sh | 5 +++-- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/config/sweep/ppmi_sweep.yaml b/config/sweep/ppmi_sweep.yaml index 55a9043..44b9a3d 100644 --- a/config/sweep/ppmi_sweep.yaml +++ b/config/sweep/ppmi_sweep.yaml @@ -1,11 +1,21 @@ +name: "PPMI Sweep" +description: "HPO for PPMI-Binary classification task" method: bayes metric: goal: maximize name: test_acc parameters: + model: + values: [ resnet_binary , convnext_binary ] optim: values: [ adamw , lion ] scheduler: values: [ exp_decay , cosine_anneal_warmup ] optim.lr: - values: [ 5e-3 , 1e-3 , 1e-4 , 5e-5 ] \ No newline at end of file + values: [ 5e-3 , 1e-3 , 1e-4 , 5e-5 ] +early_terminate: + type: hyperband + s: 2 + eta: 3 + max_iter: 27 +run_cap: 50 \ No newline at end of file diff --git a/sweep.py b/sweep.py index 6c383bf..b2adfdb 100644 --- a/sweep.py +++ b/sweep.py @@ -25,8 +25,12 @@ def parse_args(): parser.add_argument("--version_base", default="1.3", type=str, help="") parser.add_argument("--sweep_cfg_name", default="sweep.yaml", type=str, help="") - parser.add_argument("--wandb_project", default="brain-age", type=str, help="") + parser.add_argument("--wandb_project", default="brain-age", type=str, + help="Project name for training. Since we are using sweep, it is recommended to avoid `brain-age`\ + and rather use dataset name as project name.") parser.add_argument("--wandb_entity", default="1pha", type=str, help="") + parser.add_argument("--sweep_prefix", default="", type=str, + help="Prefix for sweep experiment run name.") args = parser.parse_args() return args @@ -49,8 +53,8 @@ def load_yaml(config_path: str = "config/sweep", config_name: str = "sweep.yaml" return sweep_cfg -def override_config(hydra_config: omegaconf.DictConfig, - update_dict: dict, config_path: str = "config") -> omegaconf.DictConfig: +def override_config(hydra_config: omegaconf.DictConfig, update_dict: dict, + config_path: str = "config", prefix: str = "") -> omegaconf.DictConfig: """ hydra_config: Base config update_dict : Updated key-value pairs which should be merged into hydra_config. @@ -73,7 +77,7 @@ def override_config(hydra_config: omegaconf.DictConfig, _c = _c[_k] else: _c[_k] = value - + var_sweep = " ".join([f"{k[:3]}={v}" for k, v in update_dict.items()]) ds_name = sage.utils.get_func_name(hydra_config.dataset) if hydra_config.get("dataset") else "" if "sweep" in hydra_config.get("hydra", []): @@ -82,17 +86,16 @@ def override_config(hydra_config: omegaconf.DictConfig, hydra_config.hydra.sweep.subdir = var_sweep dirpath = f"{hydra_config.hydra.sweep.dir}/{var_sweep}" hydra_config.callbacks.checkpoint.dirpath = dirpath - hydra_config.logger.name = f"{var_sweep}" + hydra_config.logger.name = f"{prefix}: {var_sweep}" if prefix else var_sweep return hydra_config -def main(config: omegaconf.DictConfig, config_path: str = "config") -> float: +def main(config: omegaconf.DictConfig, config_path: str = "config", prefix: str = "") -> float: wandb.init(project="brain-age") _config = deepcopy(config) - updated_config = override_config(hydra_config=_config, - update_dict=wandb.config, - config_path=config_path) + updated_config = override_config(hydra_config=_config, update_dict=wandb.config, + config_path=config_path, prefix=prefix) wandb.run.name = updated_config.logger.name logger.info("Start Training") @@ -105,13 +108,14 @@ def main(config: omegaconf.DictConfig, config_path: str = "config") -> float: args = parse_args() # Load hydra default configuration - overrides = ast.literal_eval(args.overrides) + overrides = ast.literal_eval(args.overrides or "[]") config = load_hydra_config(config_path=args.config_path, config_name=args.config_name, overrides=overrides, version_base=args.version_base, return_hydra_config=True) - func: Callable = partial(main, config=config, config_path=args.config_path) + func: Callable = partial(main, config=config, config_path=args.config_path, + prefix=args.sweep_prefix) # Load wandb.sweep configuration and instantiation sweep_cfg = load_yaml(config_path=os.path.join(args.config_path, "sweep"), diff --git a/sweep_command.sh b/sweep_command.sh index bca2fda..7ca424b 100755 --- a/sweep_command.sh +++ b/sweep_command.sh @@ -4,5 +4,6 @@ export CUDA_VISIBLE_DEVICES=1 python sweep.py --sweep_cfg_name=ppmi_sweep.yaml\ --wandb_project=ppmi\ --config_name=train_binary.yaml\ - --overrides="['module.load_model_ckpt=meta_brain/weights/default/resnet10-42/156864-valid_mae3.465.ckpt',\ - '+module.load_model_strict=False']" \ No newline at end of file + --sweep_prefix='Scratch' + # --overrides="['module.load_model_ckpt=meta_brain/weights/default/resnet10-42/156864-valid_mae3.465.ckpt',\ + # '+module.load_model_strict=False']" \ No newline at end of file