From f5c414e55b5f2dc91eb119687e0fe65c23d080ec Mon Sep 17 00:00:00 2001 From: 1pha <1phantasmas@korea.ac.kr> Date: Wed, 13 Mar 2024 05:42:09 +0000 Subject: [PATCH] Finalize wandb sweep integration with hydra & lightning :) --- config/callbacks/checkpoint/binary.yaml | 2 +- config/callbacks/checkpoint/cls.yaml | 2 +- config/sweep/ppmi_sweep.yaml | 12 ++-- config/train_binary.yaml | 6 +- config/trainer/debug.yaml | 4 +- meta_brain | 2 +- sage/models/base.py | 4 +- sage/trainer/trainer.py | 76 ++++++++++++++++++------- sage/trainer/utils.py | 30 +++++++--- sage/utils.py | 9 +++ sweep.py | 66 ++++++++++++--------- sweep_command.sh | 4 +- train.py | 2 +- 13 files changed, 147 insertions(+), 72 deletions(-) diff --git a/config/callbacks/checkpoint/binary.yaml b/config/callbacks/checkpoint/binary.yaml index 0738449..d5f1a6f 100644 --- a/config/callbacks/checkpoint/binary.yaml +++ b/config/callbacks/checkpoint/binary.yaml @@ -1,5 +1,5 @@ _target_: pytorch_lightning.callbacks.ModelCheckpoint -dirpath: ${hydra.run.dir} +dirpath: ${hydra:run.dir} filename: "{step}-valid_f1-{epoch/valid_BinaryF1Score:.3f}" monitor: epoch/valid_BinaryF1Score mode: max diff --git a/config/callbacks/checkpoint/cls.yaml b/config/callbacks/checkpoint/cls.yaml index 4e3d098..0361bbf 100644 --- a/config/callbacks/checkpoint/cls.yaml +++ b/config/callbacks/checkpoint/cls.yaml @@ -1,5 +1,5 @@ _target_: pytorch_lightning.callbacks.ModelCheckpoint -dirpath: ${hydra.run.dir} +dirpath: ${hydra:run.dir} filename: "{step}-valid_f1-{epoch/valid_MulticlassF1Score:.3f}" monitor: epoch/valid_MulticlassF1Score mode: max diff --git a/config/sweep/ppmi_sweep.yaml b/config/sweep/ppmi_sweep.yaml index f812a5a..55a9043 100644 --- a/config/sweep/ppmi_sweep.yaml +++ b/config/sweep/ppmi_sweep.yaml @@ -1,9 +1,11 @@ -method: random +method: bayes metric: goal: maximize - name: metric + name: test_acc parameters: - model: - values: [ resnet_binary , convnext_binary ] + optim: + values: [ adamw , lion ] + scheduler: + values: [ exp_decay , cosine_anneal_warmup ] optim.lr: - values: [ 5e-3 , 1e-3 , 5e-4 , 1e-4 , 5e-5 ] \ No newline at end of file + values: [ 5e-3 , 1e-3 , 1e-4 , 5e-5 ] \ No newline at end of file diff --git a/config/train_binary.yaml b/config/train_binary.yaml index 0bd5507..5bfd9fc 100644 --- a/config/train_binary.yaml +++ b/config/train_binary.yaml @@ -4,12 +4,12 @@ hydra: run: dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}-${logger.name} sweep: - dir: ${hydra.run.dir} + dir: ${hydra:run.dir} defaults: - _self_ - - model: convnext_cls - - dataset: ppmi + - model: resnet_binary + - dataset: ppmi_binary - scheduler: exp_decay - optim: adamw - trainer: default diff --git a/config/trainer/debug.yaml b/config/trainer/debug.yaml index d19b847..4a0d6de 100644 --- a/config/trainer/debug.yaml +++ b/config/trainer/debug.yaml @@ -6,5 +6,5 @@ gradient_clip_val: 1 log_every_n_steps: 50 accumulate_grad_batches: 4 fast_dev_run: True -limit_train_batches: 0.001 -limit_val_batches: 0.01 \ No newline at end of file +limit_train_batches: 2 +limit_val_batches: 2 \ No newline at end of file diff --git a/meta_brain b/meta_brain index 39b7391..a315502 160000 --- a/meta_brain +++ b/meta_brain @@ -1 +1 @@ -Subproject commit 39b7391de129813139e87d76009356035f77138c +Subproject commit a3155020ca2bd85539d48647c56504a12d2f8f9f diff --git a/sage/models/base.py b/sage/models/base.py index 23597ff..15b780c 100644 --- a/sage/models/base.py +++ b/sage/models/base.py @@ -25,14 +25,14 @@ def _forward(self, brain: torch.Tensor): def num_params(self): return sum(p.numel() 
for p in self.parameters() if p.requires_grad)
 
-    def load_from_checkpoint(self, ckpt: str):
+    def load_from_checkpoint(self, ckpt: str, strict: bool = True):
         ckpt = torch.load(ckpt)["state_dict"]
         def parse_ckpt(s: str):
             # This is to remove "model." prefix from pytorch_lightning
             s = ".".join(s.split(".")[1:])
             return s
         ckpt = {parse_ckpt(k): v for k, v in ckpt.items()}
-        self.load_state_dict(ckpt)
+        self.load_state_dict(ckpt, strict=strict)
 
     def conv_layers(self):
         if hasattr(self.backbone, "conv_layers"):
diff --git a/sage/trainer/trainer.py b/sage/trainer/trainer.py
index fa54594..611ae9f 100644
--- a/sage/trainer/trainer.py
+++ b/sage/trainer/trainer.py
@@ -36,6 +36,7 @@ def __init__(self,
                  augmentation: omegaconf.DictConfig = None,
                  scheduler: omegaconf.DictConfig = None,
                  load_model_ckpt: str = None,
+                 load_model_strict: bool = True,
                  load_from_checkpoint: str = None, # unused params but requires for instantiation
                  separate_lr: dict = None,
                  task: str = None,
@@ -67,7 +68,7 @@ def __init__(self,
 
         if load_model_ckpt:
             logger.info("Load checkpoint from %s", load_model_ckpt)
-            self.model.load_from_checkpoint(load_model_ckpt)
+            self.model.load_from_checkpoint(ckpt=load_model_ckpt, strict=load_model_strict)
 
         self.log_train_metrics = log_train_metrics
         self.log_lr = manual_lr
@@ -132,7 +133,7 @@ def _configure_optimizer(self,
                 if submodel is None:
                     logger.warn("separate_lr was given but submodel was not found: %s", _submodel)
                     opt_config = self._configure_optimizer(optimizer=optimizer,
-                                                           scheduler=scheduler)
+                                                            scheduler=scheduler)
                     break
                 _opt_groups.append(
                     {"params": submodel.parameters(), "lr": _lr}
@@ -174,8 +175,8 @@ def configure_scheduler(self,
             if num_training_steps:
                 struct.update({"num_training_steps": num_training_steps})
             sch = hydra.utils.instantiate(scheduler, scheduler=struct)
-        except Exception as e:
-            logger.exception(e)
+        except TypeError as e:
+            logger.exception(e)
             raise
         return sch
 
@@ -242,6 +243,26 @@ def log_confusion_matrix(self, result: dict):
         pr = wandb.plot.pr_curve(y_true=labels, y_probas=probs)
         self.logger.experiment.log({"confusion_matrix": cf, "roc_curve": roc, "pr_curve": pr})
 
+    def log_table(self, batch: Dict[str, torch.Tensor], result: Dict[str, torch.Tensor]):
+        """ Preparing table logging to wandb. """
+        if not hasattr(self, "table_columns"):
+            self.table_columns = ["PID", "Image", "Target", "Prediction", "Entropy"] + \
+                                 [f"Logit {c}" for c in range(result["cls_pred"].size(1))]
+        if not hasattr(self, "table_data"):
+            self.table_data = []
+
+        img_path, img = batch["image_path"], batch["image"]
+        for i, ind in enumerate(batch["indicator"]):
+            x, path = img[:ind], img_path[:ind]
+            pred = result["cls_pred"][i]
+            prediction = int(pred.argmax())
+            entropy = -float((pred * pred.log()).sum())
+            pred, target = pred.tolist(), int(result["cls_target"][i])
+            self.table_data.append(
+                ["\n".join(path), wandb.Image(x), target, prediction, entropy] + pred
+            )
+            img, img_path = img[ind:], img_path[ind:]
+
     def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False):
         output = {f"{unit}/{k}": float(v) for k, v in output.items()}
         self.log_dict(dictionary=output,
@@ -298,6 +319,9 @@ def on_predict_end(self):
         result = utils._sort_outputs(outputs=self.prediction_step_outputs)
         if utils.check_classification(result=result):
             self.log_confusion_matrix(result=result)
+            self.logger.log_table(key="Test Prediction", columns=["Target", "Prediction"],
+                                  data=[(t, p) for t, p in zip(result["target"].tolist(),
+                                                               result["pred"].tolist())])
 
     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
         result: dict = self.forward(batch, mode="test")
@@ -360,48 +384,56 @@ def tune(config: omegaconf.DictConfig) -> omegaconf.DictConfig:
     return config
 
 
-def train(config: omegaconf.DictConfig) -> None:
+def train(config: omegaconf.DictConfig) -> Dict[str, float]:
     config: omegaconf.DictConfig = tune(config)
-    _logger = hydra.utils.instantiate(config.logger)
+    _logger: pl.loggers.WandbLogger = hydra.utils.instantiate(config.logger)
     module, dataloaders = setup_trainer(config)
 
     # Logger Setup
     _logger.watch(module)
-    config_update: bool = "version" in config.logger or config.trainer.devices > 1
+    config_update: bool = "version" in config.logger or (config.trainer.devices > 1)
     if config_update:
         # Skip config update when using resume checkpoint
+        # or when using multiple devices
         pass
     else:
         # Hard-code config uploading
-        resolve = not "sweep" in config.hydra
-        wandb.config.update(
-            omegaconf.OmegaConf.to_container(config, resolve=resolve, throw_on_missing=resolve)
-        )
+        resolve = not "sweep" in config.get("hydra", [])
+        _cfg = omegaconf.OmegaConf.to_container(config, resolve=resolve, throw_on_missing=resolve)
+        wandb.config.update(_cfg)
 
     # Callbacks
-    callbacks: dict = hydra.utils.instantiate(config.callbacks)
+    callbacks: Dict[str, pl.Callback] = hydra.utils.instantiate(config.callbacks)
     trainer: pl.Trainer = hydra.utils.instantiate(config.trainer,
                                                   logger=_logger,
                                                   callbacks=list(callbacks.values()))
     trainer.fit(model=module,
                 train_dataloaders=dataloaders["train"],
                 val_dataloaders=dataloaders["valid"],
                 ckpt_path=config.misc.get("ckpt_path", None))
-    
+
     if dataloaders["test"]:
         logger.info("Test dataset given. Start inference on %s", len(dataloaders["test"].dataset))
-        prediction = trainer.predict(ckpt_path="best", dataloaders=dataloaders["test"])
-        metric = utils.finalize_inference(prediction=prediction, name=config.logger.name,
-                                          root_dir=Path(config.callbacks.checkpoint.dirpath))
+        dl = dataloaders["test"]
 
     elif dataloaders["valid"]:
         logger.info("Test dataset not found.
Start inference on validation dataset %s", len(dataloaders["valid"].dataset)) - prediction = trainer.predict(ckpt_path="best", dataloaders=dataloaders["valid"]) - metric = utils.finalize_inference(prediction=prediction, name=config.logger.name, - root_dir=Path(config.callbacks.checkpoint.dirpath)) - + dl = dataloaders["valid"] + else: + logger.info("No valid or test dataset found. Skip inference") + dl = None + + # Make Prediction & Log Metrics to wandb logger + if dl is not None: + prediction = trainer.predict(ckpt_path="best", dataloaders=dl) + metric: Dict[str, float] = utils.finalize_inference(prediction=prediction, + name=config.logger.name, + root_dir=Path(config.callbacks.checkpoint.dirpath)) + trainer.logger.log_metrics(metric) + if config_update: - wandb.config.update(omegaconf.OmegaConf.to_container(config, resolve=True, - throw_on_missing=True)) + # Update configuration if needed + _cfg = omegaconf.OmegaConf.to_container(config, resolve=True, throw_on_missing=True) + wandb.config.update(_cfg) return metric diff --git a/sage/trainer/utils.py b/sage/trainer/utils.py index 27eee2a..405902d 100644 --- a/sage/trainer/utils.py +++ b/sage/trainer/utils.py @@ -4,6 +4,7 @@ from datetime import datetime from pathlib import Path import pickle +from typing import Dict, List import numpy as np import pandas as pd @@ -54,7 +55,7 @@ def tune(batch_size: int = 64, return ratio -def _sort_outputs(outputs): +def _sort_outputs(outputs: Dict[str, list]): result = dict() keys: list = outputs[0].keys() for key in keys: @@ -94,7 +95,7 @@ def finalize_inference(prediction: list, pickle.dump(prediction, f) # 2. Log Predictions - run_name = save_name[:-4] + "_" + timestamp() + run_name: str = save_name[:-4] + "_" + timestamp() preds, target = prediction["pred"], prediction["target"] infer_kwargs = dict(preds=preds, target=target, root_dir=root_dir, run_name=run_name) result = dict(pred=preds, target=target) @@ -108,7 +109,11 @@ def finalize_inference(prediction: list, return metric -def _reg_inference(preds, target, root_dir, run_name) -> float: +def _reg_inference(preds: List[torch.Tensor], + target: List[torch.Tensor], + root_dir: Path, + run_name: str, + prefix: str = "test") -> Dict[str, float]: mse = tmf.mean_squared_error(preds=preds, target=target) mae = tmf.mean_absolute_error(preds=preds, target=target) r2 = tmf.r2_score(preds=preds, target=target) @@ -135,10 +140,15 @@ def _reg_inference(preds, target, root_dir, run_name) -> float: fig.tight_layout() fig.savefig(root_dir / f"{run_name}-kde.png") plt.close() - return mse + metric = {f"{prefix}_{m}": v for m, v in zip(["mse", "mae", "r2"], [mse, mae, r2])} + return metric + -def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None: +def _get_norm_cf_reg(preds: List[torch.Tensor], + target: List[torch.Tensor], + root_dir: Path, + run_name: str) -> None: """Calculate normalized confusion matrix for regression result. Calculating number of bins is done autmoatically. 1. 
Number of bins: between 5 and 10 @@ -198,7 +208,11 @@ def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None: plt.close() -def _cls_inference(preds, target, root_dir, run_name) -> float: +def _cls_inference(preds: List[torch.Tensor], + target: List[torch.Tensor], + root_dir: Path, + run_name: str, + prefix: str = "test") -> Dict[str, float]: metrics_input = dict(preds=preds, target=target.int(), task="binary") @@ -215,7 +229,9 @@ def _cls_inference(preds, target, root_dir, run_name) -> float: p.set_title(run_name) plt.savefig(root_dir / f"{run_name}-cf.png") plt.close() - return acc + + metric = {f"{prefix}_{m}": v for m, v in zip(["acc", "f1", "auroc"], [acc, f1, auroc])} + return metric def brain2augment(brain: torch.Tensor) -> torch.Tensor: diff --git a/sage/utils.py b/sage/utils.py index 6837414..f4ed79e 100644 --- a/sage/utils.py +++ b/sage/utils.py @@ -79,6 +79,15 @@ def parse_hydra(config: omegaconf, **kwargs): return inst +def get_func_name(config: omegaconf.DictConfig) -> str: + if "_target_" in config: + target = config._target_ + name = target.split(".")[-1] + else: + name = "" + return name + + def load_hydra(config_name: str, config_path: str = "config", overrides: FrozenSet[str] = None): diff --git a/sweep.py b/sweep.py index 684815d..40862ac 100644 --- a/sweep.py +++ b/sweep.py @@ -1,5 +1,6 @@ import os import ast +from copy import deepcopy import argparse from functools import partial from typing import List, Callable @@ -21,7 +22,7 @@ def parse_args(): parser.add_argument("--config_path", default="config", type=str, help="") parser.add_argument("--config_name", default="train.yaml", type=str, help="") parser.add_argument("--overrides", default="", type=str, help="") - parser.add_argument("--version_base", default="1.1", type=str, help="") + parser.add_argument("--version_base", default="1.3", type=str, help="") parser.add_argument("--sweep_cfg_name", default="sweep.yaml", type=str, help="") parser.add_argument("--wandb_project", default="brain-age", type=str, help="") @@ -31,6 +32,23 @@ def parse_args(): return args +def load_hydra_config(config_path: str = "config", + config_name: str = "train.yaml", + version_base="1.3", + overrides: List[str] = [], + return_hydra_config: bool = False) -> omegaconf.DictConfig: + with hydra.initialize(config_path=config_path, version_base=version_base): + config = hydra.compose(config_name=config_name, overrides=overrides, + return_hydra_config=return_hydra_config) + return config + + +def load_yaml(config_path: str = "config/sweep", config_name: str = "sweep.yaml") -> dict: + with open(os.path.join(config_path, config_name), mode="r") as f: + sweep_cfg = yaml.load(stream=f, Loader=yaml.FullLoader) + return sweep_cfg + + def override_config(hydra_config: omegaconf.DictConfig, update_dict: dict, config_path: str = "config") -> omegaconf.DictConfig: """ @@ -45,7 +63,8 @@ def override_config(hydra_config: omegaconf.DictConfig, if nkeys == 1: # If no . 
found in key # This implies override from defaults - _subcfg = load_yaml(config_path=f"{config_path}/{key}", config_name=f"{value}.yaml") + _subcfg = load_hydra_config(config_path=f"{config_path}/{key}", + config_name=f"{value}.yaml") hydra_config[key] = _subcfg else: _c = hydra_config[key_list[0]] @@ -54,36 +73,30 @@ def override_config(hydra_config: omegaconf.DictConfig, _c = _c[_k] else: _c[_k] = value - if "sweep" in hydra_config.hydra: + + var_sweep = " ".join([f"{k[:3]}={v}" for k, v in update_dict.items()]) + ds_name = sage.utils.get_func_name(hydra_config.dataset) if hydra_config.get("dataset") else "" + if "sweep" in hydra_config.get("hydra", []): # Configure directory for sweep. sweep_main_dir/subdir - hydra_config.hydra.sweep.subdir = "_".join([f"{k}={v}" for k, v in update_dict.items()]) - dirpath = f"{hydra_config.hydra.sweep.dir}/{hydra_config.hydra.sweep.subdir}" + hydra_config.hydra.sweep.dir = f"{hydra_config.hydra.run.dir}-{ds_name}" + hydra_config.hydra.sweep.subdir = var_sweep + dirpath = f"{hydra_config.hydra.sweep.dir}/{var_sweep}" hydra_config.callbacks.checkpoint.dirpath = dirpath - return hydra_config - + hydra_config.logger.name = f"{ds_name} {var_sweep}" -def load_default_hydra_config(config_path: str = "config", - config_name: str = "train.yaml", - version_base="1.1", - overrides: List[str] = []) -> omegaconf.DictConfig: - with hydra.initialize(config_path=config_path, version_base=version_base): - config = hydra.compose(config_name=config_name, overrides=overrides, return_hydra_config=True) - return config - - -def load_yaml(config_path: str = "config/sweep", config_name: str = "sweep.yaml") -> dict: - with open(os.path.join(config_path, config_name), mode="r") as f: - sweep_cfg = yaml.load(stream=f, Loader=yaml.FullLoader) - return sweep_cfg + return hydra_config def main(config: omegaconf.DictConfig, config_path: str = "config") -> float: wandb.init(project="brain-age") - logger.info("Sweep Config: %s", wandb.config) - updated_config = override_config(hydra_config=config, + _config = deepcopy(config) + updated_config = override_config(hydra_config=_config, update_dict=wandb.config, config_path=config_path) + wandb.run.name = updated_config.logger.name + logger.info("Start Training") + logger.info("Sweep Config: %s", wandb.config) metric = sage.trainer.train(updated_config) return metric @@ -93,10 +106,11 @@ def main(config: omegaconf.DictConfig, config_path: str = "config") -> float: # Load hydra default configuration overrides = ast.literal_eval(args.overrides) - config = load_default_hydra_config(config_path=args.config_path, - config_name=args.config_name, - overrides=overrides, - version_base=args.version_base) + config = load_hydra_config(config_path=args.config_path, + config_name=args.config_name, + overrides=overrides, + version_base=args.version_base, + return_hydra_config=True) func: Callable = partial(main, config=config, config_path=args.config_path) # Load wandb.sweep configuration and instantiation diff --git a/sweep_command.sh b/sweep_command.sh index 933e56b..bca2fda 100755 --- a/sweep_command.sh +++ b/sweep_command.sh @@ -2,5 +2,7 @@ export HYDRA_FULL_ERROR=1 export CUDA_VISIBLE_DEVICES=1 python sweep.py --sweep_cfg_name=ppmi_sweep.yaml\ + --wandb_project=ppmi\ --config_name=train_binary.yaml\ - --overrides="['dataset=ppmi_binary', 'model=convnext_binary']" \ No newline at end of file + --overrides="['module.load_model_ckpt=meta_brain/weights/default/resnet10-42/156864-valid_mae3.465.ckpt',\ + '+module.load_model_strict=False']" \ No newline 
at end of file diff --git a/train.py b/train.py index 7914f82..ebf6290 100644 --- a/train.py +++ b/train.py @@ -7,7 +7,7 @@ logger = sage.utils.get_logger(name=__name__) -@hydra.main(config_path="config", config_name="train.yaml", version_base="1.1") +@hydra.main(config_path="config", config_name="train.yaml", version_base="1.3") def main(config: omegaconf.DictConfig): logger.info("Start Training") sage.trainer.train(config)
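
The wandb.sweep / wandb.agent call that consumes the partial `main` in sweep.py is unchanged by this patch, so it does not appear in the diff above. For reference, a minimal sketch of how the pieces are typically wired together with the standard wandb API is given below; `launch_sweep`, its arguments, and the default count are illustrative names and values, not code from this repository.

# Illustrative helper (not part of this patch): launch a W&B sweep that drives sweep.py's main().
import yaml
import wandb


def launch_sweep(sweep_yaml: str, project: str, train_fn, count: int = 20) -> None:
    # Read the search space, e.g. config/sweep/ppmi_sweep.yaml
    # (method: bayes, metric: test_acc, parameters: optim / scheduler / optim.lr).
    with open(sweep_yaml) as f:
        sweep_cfg = yaml.safe_load(f)
    # Register the sweep, then let the agent call the training function repeatedly.
    # Each call receives sampled hyperparameters through wandb.config, which
    # override_config merges back into the composed hydra configuration.
    sweep_id = wandb.sweep(sweep=sweep_cfg, project=project)
    wandb.agent(sweep_id, function=train_fn, count=count, project=project)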