From f1c3138a3207a3ff74ffadafce39e921ca4e6f03 Mon Sep 17 00:00:00 2001
From: 1pha <1phantasmas@korea.ac.kr>
Date: Wed, 20 Mar 2024 04:42:11 +0000
Subject: [PATCH] Add PR/AUROC curve. Fix wandb log_table logic. Fix regression earlystop criteria

---
 config/callbacks/early_stop/binary.yaml |  8 ++---
 config/callbacks/early_stop/reg.yaml    |  2 +-
 config/sweep/adni_sweep.yaml            | 21 +++++++++++
 sage/data/adni.py                       |  1 -
 sage/data/ppmi.py                       |  1 -
 sage/models/base.py                     |  5 +--
 sage/trainer/trainer.py                 | 26 +++-----------
 sweep_command.sh                        | 46 ++++++++++++++++++++-----
 8 files changed, 71 insertions(+), 39 deletions(-)
 create mode 100644 config/sweep/adni_sweep.yaml

diff --git a/config/callbacks/early_stop/binary.yaml b/config/callbacks/early_stop/binary.yaml
index 6c90046..37a2251 100644
--- a/config/callbacks/early_stop/binary.yaml
+++ b/config/callbacks/early_stop/binary.yaml
@@ -1,4 +1,4 @@
- _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
- monitor: epoch/valid_BinaryF1Score
- mode: max
- patience: 10
\ No newline at end of file
+_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
+monitor: epoch/valid_BinaryF1Score
+mode: max
+patience: 10
\ No newline at end of file
diff --git a/config/callbacks/early_stop/reg.yaml b/config/callbacks/early_stop/reg.yaml
index a6e77e7..5297170 100644
--- a/config/callbacks/early_stop/reg.yaml
+++ b/config/callbacks/early_stop/reg.yaml
@@ -1,4 +1,4 @@
 _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
 monitor: epoch/valid_MeanSquaredError
-mode: max
+mode: min
 patience: 10
\ No newline at end of file
diff --git a/config/sweep/adni_sweep.yaml b/config/sweep/adni_sweep.yaml
new file mode 100644
index 0000000..d534ec3
--- /dev/null
+++ b/config/sweep/adni_sweep.yaml
@@ -0,0 +1,21 @@
+name: "ADNI Sweep"
+description: "HPO for ADNI classification task"
+method: bayes
+metric:
+  goal: maximize
+  name: test_acc
+parameters:
+  model:
+    values: [ resnet_binary , convnext_binary ]
+  optim:
+    values: [ adamw , lion ]
+  scheduler:
+    values: [ exp_decay , cosine_anneal_warmup ]
+  optim.lr:
+    values: [ 5e-3 , 1e-3 , 1e-4 , 5e-5 ]
+early_terminate:
+  type: hyperband
+  s: 2
+  eta: 3
+  max_iter: 27
+run_cap: 50
\ No newline at end of file
diff --git a/sage/data/adni.py b/sage/data/adni.py
index 7b0c687..ed30319 100644
--- a/sage/data/adni.py
+++ b/sage/data/adni.py
@@ -41,7 +41,6 @@ def _load_data(self, idx: int) -> Tuple[torch.Tensor]:
         """ Make sure to properly return PPMI """
         raise NotImplementedError
 
-    @overrides
     def _exclude_data(self, labels: pd.DataFrame, pk_col: str, root: Path,
                       exclusion_fname: str = "donotuse-adni.txt") -> pd.DataFrame:
         """ TODO: Remove exclude from label """
diff --git a/sage/data/ppmi.py b/sage/data/ppmi.py
index 989d8e9..149b260 100644
--- a/sage/data/ppmi.py
+++ b/sage/data/ppmi.py
@@ -2,7 +2,6 @@
 from typing import Tuple, List
 
 import torch
-import pandas as pd
 
 from sage.data.dataloader import DatasetBase, open_scan
 import sage.constants as C
diff --git a/sage/models/base.py b/sage/models/base.py
index 15b780c..2b805fd 100644
--- a/sage/models/base.py
+++ b/sage/models/base.py
@@ -43,14 +43,15 @@ def conv_layers(self):
 
 class RegBase(ModelBase):
     def forward(self, brain: torch.Tensor, age: torch.Tensor):
-        pred = self.backbone(brain).squeeze()
+        # Specify squeeze dimension to prevent batch_size=1 being squeezed to a single scalar.
+        pred = self.backbone(brain).squeeze(dim=1)
         loss = self.criterion(pred, age.float())
         return dict(loss=loss, pred=pred.detach(), target=age.detach())
 
 
 class ClsBase(ModelBase):
     def forward(self, brain: torch.Tensor, age: torch.Tensor):
-        pred = self.backbone(brain).squeeze()
+        pred = self.backbone(brain).squeeze(dim=1)
         loss = self.criterion(pred, age.long())
         return dict(loss=loss, pred=pred.detach(), target=age.detach().long())
 
diff --git a/sage/trainer/trainer.py b/sage/trainer/trainer.py
index 611ae9f..529332c 100644
--- a/sage/trainer/trainer.py
+++ b/sage/trainer/trainer.py
@@ -212,7 +212,7 @@ def forward(self, batch, mode: str = "train"):
             logger.exception(e)
             breakpoint()
             raise e
-    
+
     def move_device(self,
                     result: Dict[str, torch.Tensor],
                     exclude_keys: List[str] = ["loss"]) -> Dict[str, torch.Tensor]:
@@ -243,26 +243,6 @@ def log_confusion_matrix(self, result: dict):
         pr = wandb.plot.pr_curve(y_true=labels, y_probas=probs)
         self.logger.experiment.log({"confusion_matrix": cf, "roc_curve": roc, "pr_curve": pr})
 
-    def log_table(self, batch: Dict[str, torch.Tensor], result: Dict[str, torch.Tensor]):
-        """ Preparing table logging to wandb. """
-        if not hasattr(self, "table_columns"):
-            self.table_columns = ["PID", "Image", "Target", "Prediction", "Entropy"] + \
-                                 [f"Logit {c}" for c in range(result["cls_pred"].size(1))]
-        if not hasattr(self, "table_data"):
-            self.table_data = []
-
-        img_path, img = batch["image_path"], batch["image"]
-        for i, ind in enumerate(batch["indicator"]):
-            x, path = img[:ind], img_path[:ind]
-            pred = result["cls_pred"][i]
-            prediction = int(pred.argmax())
-            entropy = -float((pred * pred.log()).sum())
-            pred, target = pred.tolist(), int(result["cls_target"][i])
-            self.table_data.append(
-                ["\n".join(path), wandb.Image(x), target, prediction, entropy] + pred
-            )
-            img, img_path = img[ind:], img_path[ind:]
-
     def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False):
         output = {f"{unit}/{k}": float(v) for k, v in output.items()}
         self.log_dict(dictionary=output,
@@ -321,7 +301,7 @@ def on_predict_end(self):
             self.log_confusion_matrix(result=result)
             self.logger.log_table(key="Test Prediction", columns=["Target", "Prediction"],
                                   data=[(t, p) for t, p in zip(result["target"].tolist(),
-                                                               result["pred"].tolist())])
+                                                               nn.functional.sigmoid(result["pred"]).tolist())])
 
     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
         result: dict = self.forward(batch, mode="test")
@@ -429,6 +409,8 @@ def train(config: omegaconf.DictConfig) -> Dict[str, float]:
                             name=config.logger.name,
                             root_dir=Path(config.callbacks.checkpoint.dirpath))
         trainer.logger.log_metrics(metric)
+    else:
+        metric = None
 
     if config_update:
         # Update configuration if needed
diff --git a/sweep_command.sh b/sweep_command.sh
index 7ca424b..587daa7 100755
--- a/sweep_command.sh
+++ b/sweep_command.sh
@@ -1,9 +1,39 @@
 export HYDRA_FULL_ERROR=1
-export CUDA_VISIBLE_DEVICES=1
-
-python sweep.py --sweep_cfg_name=ppmi_sweep.yaml\
-    --wandb_project=ppmi\
-    --config_name=train_binary.yaml\
-    --sweep_prefix='Scratch'
-    # --overrides="['module.load_model_ckpt=meta_brain/weights/default/resnet10-42/156864-valid_mae3.465.ckpt',\
-    #               '+module.load_model_strict=False']"
\ No newline at end of file
+
+read -p "Enter devices: " device
+export CUDA_VISIBLE_DEVICES=$device
+
+read -p "Enter dataset: ppmi|adni " ds
+dataset=$ds
+
+sweep_ppmi() {
+    echo "Sweep on PPMI"
+    python sweep.py --sweep_cfg_name=ppmi_sweep.yaml\
+        --wandb_project=ppmi\
+        --config_name=train_binary.yaml\
+        --sweep_prefix='Scratch'\
--overrides="['dataset=ppmi_binary', \ + '+dataset.modality=[t2]', \ + 'dataloader.batch_size=4', \ + 'dataloader.num_workers=2', \ + 'trainer.accumulate_grad_batches=8']" +} + +sweep_adni() { + echo "Sweep on ADNI" + python sweep.py --sweep_cfg_name=adni_sweep.yaml\ + --wandb_project=adni\ + --config_name=train_binary.yaml\ + --sweep_prefix='Scratch'\ + --overrides="['dataset=adni']" +} + +# Check the input argument and call the appropriate function +if [ $dataset = "ppmi" ]; then + sweep_ppmi +elif [ $dataset = "adni" ]; then + sweep_adni +else + echo "Invalid argument. Usage: $0 [ppmi|adni]. Got $dataset instead" + exit 1 +fi \ No newline at end of file
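
Note on the squeeze(dim=1) change in sage/models/base.py: a bare squeeze()
removes every size-1 axis, so a (1, 1) backbone output for a batch of one
collapses to a 0-dim scalar and the criterion sees mismatched shapes. A
minimal sketch (plain PyTorch, not part of the patch) of the difference:

    import torch

    out = torch.randn(1, 1)          # backbone output for batch_size=1
    age = torch.tensor([72.0])       # one regression target

    print(out.squeeze().shape)       # torch.Size([])  -- batch axis lost
    print(out.squeeze(dim=1).shape)  # torch.Size([1]) -- batch axis kept

    # With dim=1 the prediction and target shapes match ([1] vs [1]):
    loss = torch.nn.MSELoss()(out.squeeze(dim=1), age)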
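
Note on config/callbacks/early_stop/reg.yaml: MeanSquaredError improves as it
decreases, so monitoring it with mode: max would stop training at the worst
checkpoint; mode: min is the correct direction. The Hydra config instantiates
the equivalent of this standard pytorch_lightning call:

    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    # Stop when validation MSE has not decreased for 10 consecutive epochs.
    early_stop = EarlyStopping(monitor="epoch/valid_MeanSquaredError",
                               mode="min", patience=10)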
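
Note on config/sweep/adni_sweep.yaml: the file follows the wandb sweep schema
(Bayesian search over model/optim/scheduler/optim.lr, hyperband early
termination, capped at 50 runs). sweep.py itself is not shown in this patch; a
typical wandb flow for such a config (a hedged sketch -- train_fn is a
placeholder, not the repo's actual entry point) would be:

    import yaml
    import wandb

    def train_fn():
        # Placeholder: sweep.py presumably builds the Hydra overrides from
        # wandb.config and launches training here.
        run = wandb.init()
        run.finish()

    with open("config/sweep/adni_sweep.yaml") as f:
        sweep_cfg = yaml.safe_load(f)

    sweep_id = wandb.sweep(sweep=sweep_cfg, project="adni")
    wandb.agent(sweep_id, function=train_fn, count=sweep_cfg["run_cap"])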