Skip to content

Commit

Permalink
Add wandb prefix. Remove pre-trained checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 13, 2024
1 parent 5649f2d commit c2abd9a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 14 deletions.
12 changes: 11 additions & 1 deletion config/sweep/ppmi_sweep.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
name: "PPMI Sweep"
description: "HPO for PPMI-Binary classification task"
method: bayes
metric:
goal: maximize
name: test_acc
parameters:
model:
values: [ resnet_binary , convnext_binary ]
optim:
values: [ adamw , lion ]
scheduler:
values: [ exp_decay , cosine_anneal_warmup ]
optim.lr:
values: [ 5e-3 , 1e-3 , 1e-4 , 5e-5 ]
values: [ 5e-3 , 1e-3 , 1e-4 , 5e-5 ]
early_terminate:
type: hyperband
s: 2
eta: 3
max_iter: 27
run_cap: 50
26 changes: 15 additions & 11 deletions sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ def parse_args():
parser.add_argument("--version_base", default="1.3", type=str, help="")

parser.add_argument("--sweep_cfg_name", default="sweep.yaml", type=str, help="")
parser.add_argument("--wandb_project", default="brain-age", type=str, help="")
parser.add_argument("--wandb_project", default="brain-age", type=str,
help="Project name for training. Since we are using sweep, it is recommended to avoid `brain-age`\
and rather use dataset name as project name.")
parser.add_argument("--wandb_entity", default="1pha", type=str, help="")
parser.add_argument("--sweep_prefix", default="", type=str,
help="Prefix for sweep experiment run name.")

args = parser.parse_args()
return args
Expand All @@ -49,8 +53,8 @@ def load_yaml(config_path: str = "config/sweep", config_name: str = "sweep.yaml"
return sweep_cfg


def override_config(hydra_config: omegaconf.DictConfig,
update_dict: dict, config_path: str = "config") -> omegaconf.DictConfig:
def override_config(hydra_config: omegaconf.DictConfig, update_dict: dict,
config_path: str = "config", prefix: str = "") -> omegaconf.DictConfig:
"""
hydra_config: Base config
update_dict : Updated key-value pairs which should be merged into hydra_config.
Expand All @@ -73,7 +77,7 @@ def override_config(hydra_config: omegaconf.DictConfig,
_c = _c[_k]
else:
_c[_k] = value

var_sweep = " ".join([f"{k[:3]}={v}" for k, v in update_dict.items()])
ds_name = sage.utils.get_func_name(hydra_config.dataset) if hydra_config.get("dataset") else ""
if "sweep" in hydra_config.get("hydra", []):
Expand All @@ -82,17 +86,16 @@ def override_config(hydra_config: omegaconf.DictConfig,
hydra_config.hydra.sweep.subdir = var_sweep
dirpath = f"{hydra_config.hydra.sweep.dir}/{var_sweep}"
hydra_config.callbacks.checkpoint.dirpath = dirpath
hydra_config.logger.name = f"{var_sweep}"
hydra_config.logger.name = f"{prefix}: {var_sweep}" if prefix else var_sweep

return hydra_config


def main(config: omegaconf.DictConfig, config_path: str = "config") -> float:
def main(config: omegaconf.DictConfig, config_path: str = "config", prefix: str = "") -> float:
wandb.init(project="brain-age")
_config = deepcopy(config)
updated_config = override_config(hydra_config=_config,
update_dict=wandb.config,
config_path=config_path)
updated_config = override_config(hydra_config=_config, update_dict=wandb.config,
config_path=config_path, prefix=prefix)
wandb.run.name = updated_config.logger.name

logger.info("Start Training")
Expand All @@ -105,13 +108,14 @@ def main(config: omegaconf.DictConfig, config_path: str = "config") -> float:
args = parse_args()

# Load hydra default configuration
overrides = ast.literal_eval(args.overrides)
overrides = ast.literal_eval(args.overrides or "[]")
config = load_hydra_config(config_path=args.config_path,
config_name=args.config_name,
overrides=overrides,
version_base=args.version_base,
return_hydra_config=True)
func: Callable = partial(main, config=config, config_path=args.config_path)
func: Callable = partial(main, config=config, config_path=args.config_path,
prefix=args.sweep_prefix)

# Load wandb.sweep configuration and instantiation
sweep_cfg = load_yaml(config_path=os.path.join(args.config_path, "sweep"),
Expand Down
5 changes: 3 additions & 2 deletions sweep_command.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ export CUDA_VISIBLE_DEVICES=1
python sweep.py --sweep_cfg_name=ppmi_sweep.yaml\
--wandb_project=ppmi\
--config_name=train_binary.yaml\
--overrides="['module.load_model_ckpt=meta_brain/weights/default/resnet10-42/156864-valid_mae3.465.ckpt',\
'+module.load_model_strict=False']"
--sweep_prefix='Scratch'
# --overrides="['module.load_model_ckpt=meta_brain/weights/default/resnet10-42/156864-valid_mae3.465.ckpt',\
# '+module.load_model_strict=False']"

0 comments on commit c2abd9a

Please sign in to comment.