Skip to content

Commit

Permalink
Finalize wandb sweep integration with hydra & lightning :)
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 13, 2024
1 parent e9908b2 commit f5c414e
Show file tree
Hide file tree
Showing 13 changed files with 147 additions and 72 deletions.
2 changes: 1 addition & 1 deletion config/callbacks/checkpoint/binary.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_target_: pytorch_lightning.callbacks.ModelCheckpoint
dirpath: ${hydra.run.dir}
dirpath: ${hydra:run.dir}
filename: "{step}-valid_f1-{epoch/valid_BinaryF1Score:.3f}"
monitor: epoch/valid_BinaryF1Score
mode: max
Expand Down
2 changes: 1 addition & 1 deletion config/callbacks/checkpoint/cls.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_target_: pytorch_lightning.callbacks.ModelCheckpoint
dirpath: ${hydra.run.dir}
dirpath: ${hydra:run.dir}
filename: "{step}-valid_f1-{epoch/valid_MulticlassF1Score:.3f}"
monitor: epoch/valid_MulticlassF1Score
mode: max
Expand Down
12 changes: 7 additions & 5 deletions config/sweep/ppmi_sweep.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
method: random
method: bayes
metric:
goal: maximize
name: metric
name: test_acc
parameters:
model:
values: [ resnet_binary , convnext_binary ]
optim:
values: [ adamw , lion ]
scheduler:
values: [ exp_decay , cosine_anneal_warmup ]
optim.lr:
values: [ 5e-3 , 1e-3 , 5e-4 , 1e-4 , 5e-5 ]
values: [ 5e-3 , 1e-3 , 1e-4 , 5e-5 ]
6 changes: 3 additions & 3 deletions config/train_binary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ hydra:
run:
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}-${logger.name}
sweep:
dir: ${hydra.run.dir}
dir: ${hydra:run.dir}

defaults:
- _self_
- model: convnext_cls
- dataset: ppmi
- model: resnet_binary
- dataset: ppmi_binary
- scheduler: exp_decay
- optim: adamw
- trainer: default
Expand Down
4 changes: 2 additions & 2 deletions config/trainer/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ gradient_clip_val: 1
log_every_n_steps: 50
accumulate_grad_batches: 4
fast_dev_run: True
limit_train_batches: 0.001
limit_val_batches: 0.01
limit_train_batches: 2
limit_val_batches: 2
2 changes: 1 addition & 1 deletion meta_brain
4 changes: 2 additions & 2 deletions sage/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ def _forward(self, brain: torch.Tensor):
def num_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)

def load_from_checkpoint(self, ckpt: str, strict: bool = True):
    """Load model weights from a pytorch-lightning checkpoint file.

    Args:
        ckpt: Path to the checkpoint file. It is expected to contain a
            ``"state_dict"`` entry whose keys carry the ``model.`` prefix
            that pytorch_lightning prepends to module parameters.
        strict: Forwarded to ``load_state_dict``. Pass ``False`` to allow
            missing/unexpected keys (e.g. when loading a backbone into a
            model with a different head).
    """
    # Use a separate name instead of shadowing the `ckpt` path argument.
    state_dict = torch.load(ckpt)["state_dict"]

    def parse_ckpt(key: str) -> str:
        # Remove the leading "model." prefix added by pytorch_lightning.
        return ".".join(key.split(".")[1:])

    state_dict = {parse_ckpt(k): v for k, v in state_dict.items()}
    self.load_state_dict(state_dict, strict=strict)

def conv_layers(self):
if hasattr(self.backbone, "conv_layers"):
Expand Down
76 changes: 54 additions & 22 deletions sage/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self,
augmentation: omegaconf.DictConfig = None,
scheduler: omegaconf.DictConfig = None,
load_model_ckpt: str = None,
load_model_strict: bool = True,
load_from_checkpoint: str = None, # unused params but requires for instantiation
separate_lr: dict = None,
task: str = None,
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(self,

if load_model_ckpt:
logger.info("Load checkpoint from %s", load_model_ckpt)
self.model.load_from_checkpoint(load_model_ckpt)
self.model.load_from_checkpoint(ckpt=load_model_ckpt, strict=load_model_strict)

self.log_train_metrics = log_train_metrics
self.log_lr = manual_lr
Expand Down Expand Up @@ -132,7 +133,7 @@ def _configure_optimizer(self,
if submodel is None:
logger.warn("separate_lr was given but submodel was not found: %s", _submodel)
opt_config = self._configure_optimizer(optimizer=optimizer,
scheduler=scheduler)
scheduler=scheduler)
break
_opt_groups.append(
{"params": submodel.parameters(), "lr": _lr}
Expand Down Expand Up @@ -174,8 +175,8 @@ def configure_scheduler(self,
if num_training_steps:
struct.update({"num_training_steps": num_training_steps})
sch = hydra.utils.instantiate(scheduler, scheduler=struct)
except Exception as e:
logger.exception(e)
except TypeError as e:
breakpoint()
raise
return sch

Expand Down Expand Up @@ -242,6 +243,26 @@ def log_confusion_matrix(self, result: dict):
pr = wandb.plot.pr_curve(y_true=labels, y_probas=probs)
self.logger.experiment.log({"confusion_matrix": cf, "roc_curve": roc, "pr_curve": pr})

def log_table(self, batch: Dict[str, torch.Tensor], result: Dict[str, torch.Tensor]):
    """Accumulate per-sample rows for a wandb prediction table.

    Lazily creates ``self.table_columns`` / ``self.table_data`` on the first
    call; rows are only appended here and are presumably flushed to wandb
    elsewhere (not shown in this view).

    Args:
        batch: Must contain ``"image_path"``, ``"image"`` and ``"indicator"``.
            ``"indicator"`` appears to give, per sample, how many leading
            entries of ``image``/``image_path`` belong to that sample — the
            buffers are consumed chunk-by-chunk below. TODO confirm against
            the collate function.
        result: Must contain ``"cls_pred"`` (2-D: sample x class scores) and
            ``"cls_target"`` (integer class labels).
    """
    # One "Logit c" column per class, inferred from the prediction width.
    if not hasattr(self, "table_columns"):
        self.table_columns = ["PID", "Image", "Target", "Prediction", "Entropy"] + \
                             [f"Logit {c}" for c in range(result["cls_pred"].size(1))]
    if not hasattr(self, "table_data"):
        self.table_data = []

    img_path, img = batch["image_path"], batch["image"]
    for i, ind in enumerate(batch["indicator"]):
        # Take this sample's leading chunk of images/paths ...
        x, path = img[:ind], img_path[:ind]
        pred = result["cls_pred"][i]
        prediction = int(pred.argmax())
        # NOTE(review): this entropy formula assumes `pred` holds
        # probabilities (non-negative, summing to 1); raw logits would make
        # pred.log() yield NaNs. The "Logit c" column names suggest the
        # opposite — confirm which one it actually is.
        entropy = -float((pred * pred.log()).sum())
        pred, target = pred.tolist(), int(result["cls_target"][i])
        self.table_data.append(
            ["\n".join(path), wandb.Image(x), target, prediction, entropy] + pred
        )
        # ... then drop the consumed chunk before moving to the next sample.
        img, img_path = img[ind:], img_path[ind:]

def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False):
output = {f"{unit}/{k}": float(v) for k, v in output.items()}
self.log_dict(dictionary=output,
Expand Down Expand Up @@ -298,6 +319,9 @@ def on_predict_end(self):
result = utils._sort_outputs(outputs=self.prediction_step_outputs)
if utils.check_classification(result=result):
self.log_confusion_matrix(result=result)
self.logger.log_table(key="Test Prediction", columns=["Target", "Prediction"],
data=[(t, p) for t, p in zip(result["target"].tolist(),
result["pred"].tolist())])

def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
result: dict = self.forward(batch, mode="test")
Expand Down Expand Up @@ -360,48 +384,56 @@ def tune(config: omegaconf.DictConfig) -> omegaconf.DictConfig:
return config


def train(config: omegaconf.DictConfig) -> Dict[str, float]:
    """Run a full training session from a composed hydra config.

    Instantiates the logger, model/dataloaders, callbacks and trainer from
    ``config``, fits the model, then runs inference on the test set (or the
    validation set as a fallback) and logs the resulting metrics to wandb.

    Args:
        config: Composed hydra configuration (possibly tuned by a sweep).

    Returns:
        Metric dict from ``utils.finalize_inference`` — empty when neither a
        test nor a validation dataloader is available.
    """
    config = tune(config)
    _logger = hydra.utils.instantiate(config.logger)  # wandb logger instance
    module, dataloaders = setup_trainer(config)

    # Logger setup
    _logger.watch(module)
    # Defer the config upload when resuming from a checkpoint ("version"
    # present) or when running on multiple devices.
    config_update: bool = "version" in config.logger or (config.trainer.devices > 1)
    if not config_update:
        # Hard-code config uploading. During a sweep the config may hold
        # unresolved interpolations, so skip resolution in that case.
        resolve = "sweep" not in config.get("hydra", [])
        _cfg = omegaconf.OmegaConf.to_container(config, resolve=resolve, throw_on_missing=resolve)
        wandb.config.update(_cfg)

    # Callbacks
    callbacks: Dict[str, pl.Callback] = hydra.utils.instantiate(config.callbacks)
    trainer: pl.Trainer = hydra.utils.instantiate(config.trainer, logger=_logger,
                                                  callbacks=list(callbacks.values()))
    trainer.fit(model=module,
                train_dataloaders=dataloaders["train"],
                val_dataloaders=dataloaders["valid"],
                ckpt_path=config.misc.get("ckpt_path", None))

    # Choose the dataloader for final inference: prefer test, fall back to valid.
    if dataloaders["test"]:
        logger.info("Test dataset given. Start inference on %s", len(dataloaders["test"].dataset))
        dl = dataloaders["test"]
    elif dataloaders["valid"]:
        logger.info("Test dataset not found. Start inference on validation dataset %s",
                    len(dataloaders["valid"].dataset))
        dl = dataloaders["valid"]
    else:
        logger.info("No valid or test dataset found. Skip inference")
        dl = None

    # Make prediction & log metrics to the wandb logger.
    # BUGFIX: `metric` was previously unbound (NameError at `return metric`)
    # when no test/valid dataloader existed; default to an empty dict.
    metric: Dict[str, float] = {}
    if dl is not None:
        prediction = trainer.predict(ckpt_path="best", dataloaders=dl)
        metric = utils.finalize_inference(prediction=prediction,
                                          name=config.logger.name,
                                          root_dir=Path(config.callbacks.checkpoint.dirpath))
        trainer.logger.log_metrics(metric)

    if config_update:
        # The upload was deferred above; by now everything is resolvable.
        _cfg = omegaconf.OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
        wandb.config.update(_cfg)
    return metric


Expand Down
30 changes: 23 additions & 7 deletions sage/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from pathlib import Path
import pickle
from typing import Dict, List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -54,7 +55,7 @@ def tune(batch_size: int = 64,
return ratio


def _sort_outputs(outputs):
def _sort_outputs(outputs: Dict[str, list]):
result = dict()
keys: list = outputs[0].keys()
for key in keys:
Expand Down Expand Up @@ -94,7 +95,7 @@ def finalize_inference(prediction: list,
pickle.dump(prediction, f)

# 2. Log Predictions
run_name = save_name[:-4] + "_" + timestamp()
run_name: str = save_name[:-4] + "_" + timestamp()
preds, target = prediction["pred"], prediction["target"]
infer_kwargs = dict(preds=preds, target=target, root_dir=root_dir, run_name=run_name)
result = dict(pred=preds, target=target)
Expand All @@ -108,7 +109,11 @@ def finalize_inference(prediction: list,
return metric


def _reg_inference(preds, target, root_dir, run_name) -> float:
def _reg_inference(preds: List[torch.Tensor],
target: List[torch.Tensor],
root_dir: Path,
run_name: str,
prefix: str = "test") -> Dict[str, float]:
mse = tmf.mean_squared_error(preds=preds, target=target)
mae = tmf.mean_absolute_error(preds=preds, target=target)
r2 = tmf.r2_score(preds=preds, target=target)
Expand All @@ -135,10 +140,15 @@ def _reg_inference(preds, target, root_dir, run_name) -> float:
fig.tight_layout()
fig.savefig(root_dir / f"{run_name}-kde.png")
plt.close()
return mse

metric = {f"{prefix}_{m}": v for m, v in zip(["mse", "mae", "r2"], [mse, mae, r2])}
return metric


def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None:
def _get_norm_cf_reg(preds: List[torch.Tensor],
target: List[torch.Tensor],
root_dir: Path,
run_name: str) -> None:
"""Calculate normalized confusion matrix for regression result.
Calculating number of bins is done automatically.
1. Number of bins: between 5 and 10
Expand Down Expand Up @@ -198,7 +208,11 @@ def _get_norm_cf_reg(preds, target, root_dir, run_name) -> None:
plt.close()


def _cls_inference(preds, target, root_dir, run_name) -> float:
def _cls_inference(preds: List[torch.Tensor],
target: List[torch.Tensor],
root_dir: Path,
run_name: str,
prefix: str = "test") -> Dict[str, float]:
metrics_input = dict(preds=preds,
target=target.int(),
task="binary")
Expand All @@ -215,7 +229,9 @@ def _cls_inference(preds, target, root_dir, run_name) -> float:
p.set_title(run_name)
plt.savefig(root_dir / f"{run_name}-cf.png")
plt.close()
return acc

metric = {f"{prefix}_{m}": v for m, v in zip(["acc", "f1", "auroc"], [acc, f1, auroc])}
return metric


def brain2augment(brain: torch.Tensor) -> torch.Tensor:
Expand Down
9 changes: 9 additions & 0 deletions sage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def parse_hydra(config: omegaconf, **kwargs):
return inst


def get_func_name(config: "omegaconf.DictConfig") -> str:
    """Return the final component of a Hydra ``_target_`` path.

    E.g. ``{"_target_": "pytorch_lightning.callbacks.ModelCheckpoint"}``
    yields ``"ModelCheckpoint"``. Returns ``""`` when no ``_target_`` key
    is present.

    Uses ``.get`` instead of attribute access so plain ``dict`` configs work
    as well as ``DictConfig`` (backward compatible generalization).
    """
    target = config.get("_target_", "")
    return target.rsplit(".", 1)[-1] if target else ""


def load_hydra(config_name: str,
config_path: str = "config",
overrides: FrozenSet[str] = None):
Expand Down
Loading

0 comments on commit f5c414e

Please sign in to comment.