diff --git a/sage/trainer/trainer.py b/sage/trainer/trainer.py
index a9fdb79..fa54594 100644
--- a/sage/trainer/trainer.py
+++ b/sage/trainer/trainer.py
@@ -78,10 +78,10 @@ def __init__(self,
         self.init_transforms(augmentation=augmentation)
         self.save_dir = Path(save_dir)
         self.task = task
-    
+
     def setup(self, stage):
         self.log_brain(return_path=False)
-    
+
     def init_transforms(self, augmentation: omegaconf.DictConfig):
         self.train_transforms = mt.Compose([
             mt.Lambda(func=utils.brain2augment),
@@ -89,7 +89,7 @@ def init_transforms(self, augmentation: omegaconf.DictConfig):
             mt.ScaleIntensity(channel_wise=True),
             mt.RandAdjustContrast(prob=0.1, gamma=(0.5, 2.0)),
             mt.RandCoarseDropout(holes=20, spatial_size=8, prob=0.4, fill_value=0.),
-            # mt.RandAxisFlip(prob=0.5),
+            mt.RandAxisFlip(prob=0.5),
             mt.RandZoom(prob=0.4, min_zoom=0.9, max_zoom=1.4, mode="trilinear"),
             mt.Lambda(func=utils.augment2brain),
         ])
@@ -227,12 +227,20 @@ def move_device(self,
             if key not in exclude_keys:
                 result[key] = result[key].to("cpu")
         return result
-    
+
     def log_confusion_matrix(self, result: dict):
-        probs = result["pred"].cpu().detach()
+        probs = result["pred"]
+        if probs.ndim == 1:
+            # Binary classification: stack sigmoid outputs into a (B, 2) probability matrix
+            probs = torch.sigmoid(probs)
+            probs = torch.stack([1 - probs, probs]).T.cpu().numpy()
+        else:
+            probs = probs.cpu().detach()
         labels = result["target"].cpu().numpy()
         cf = wandb.plot.confusion_matrix(probs=probs, y_true=labels)
-        self.logger.experiment.log({"confusion_matrix": cf})
+        roc = wandb.plot.roc_curve(y_true=labels, y_probas=probs)
+        pr = wandb.plot.pr_curve(y_true=labels, y_probas=probs)
+        self.logger.experiment.log({"confusion_matrix": cf, "roc_curve": roc, "pr_curve": pr})

     def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False):
         output = {f"{unit}/{k}": float(v) for k, v in output.items()}
@@ -262,6 +270,7 @@ def on_train_epoch_end(self):
         output: dict = self.train_metric.compute()
         self.log_result(output, unit="epoch")
         self.training_step_outputs.clear()
+        self.train_metric.reset()

     def validation_step(self, batch, batch_idx):
         result: dict = self.forward(batch, mode="valid")
@@ -275,10 +284,10 @@ def on_validation_epoch_end(self):
         self.log_result(output, unit="epoch", prog_bar=True)

         result = utils._sort_outputs(outputs=self.validation_step_outputs)
-        if result["pred"].ndim == 2:
-            """ Assuming prediction with (B, C) shape is a classification task"""
+        if utils.check_classification(result=result):
             self.log_confusion_matrix(result=result)
         self.validation_step_outputs.clear()
+        self.valid_metric.reset()

     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
         result: dict = self.forward(batch, mode="test")
@@ -287,8 +296,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):

     def on_predict_end(self):
         result = utils._sort_outputs(outputs=self.prediction_step_outputs)
-        if result["pred"].ndim == 2:
-            """ Assuming prediction with (B, C) shape is a classification task"""
+        if utils.check_classification(result=result):
             self.log_confusion_matrix(result=result)

     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
diff --git a/sage/trainer/utils.py b/sage/trainer/utils.py
index 8424e09..ac70259 100644
--- a/sage/trainer/utils.py
+++ b/sage/trainer/utils.py
@@ -265,3 +265,14 @@ def augment2brain(brain: torch.Tensor) -> torch.Tensor:
     C = brain.shape[1]
     assert C == 1, f"Brain should have single-channel: #channels = {C}"
     return brain
+
+
+def check_classification(result: dict) -> bool:
+    pred, target = result["pred"], result["target"]
+    if pred.ndim == 2:
+        # (B, C) predictions imply a multi-class classification task
+        return True
+    elif (pred.ndim == 1) and (target.unique().size(0) == 2):
+        # Binary case: 1-D predictions with exactly two distinct target values
+        return True
+    else:
+        return False
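Note on the new binary branch in `log_confusion_matrix`: `wandb.plot.confusion_matrix`, `wandb.plot.roc_curve`, and `wandb.plot.pr_curve` expect per-class probabilities of shape `(N, num_classes)`, so 1-D binary logits are mapped through a sigmoid and stacked into two columns. A minimal standalone sketch of that conversion, assuming `pred` holds raw binary logits of shape `(N,)` (the logit values below are made up):

```python
import torch

logits = torch.tensor([-1.2, 0.3, 2.5])   # assumed raw binary logits, shape (N,)
p = torch.sigmoid(logits)                 # P(class = 1) per sample
probs = torch.stack([1 - p, p]).T.numpy() # shape (N, 2); row i is [P(0), P(1)]
print(probs.shape)                        # (3, 2), each row sums to 1
```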
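Replacing the inline `pred.ndim == 2` checks with `utils.check_classification` also picks up the binary case: `(B, C)` predictions are treated as multi-class, while 1-D predictions count as classification only when the targets take exactly two distinct values, which keeps regression outputs (e.g. predicted brain ages) out of the confusion-matrix/ROC/PR logging. A small sketch of the heuristic on dummy tensors (case names and shapes are illustrative only):

```python
import torch

cases = {
    "multiclass": {"pred": torch.randn(8, 3), "target": torch.randint(0, 3, (8,))},
    "binary":     {"pred": torch.randn(8),    "target": torch.tensor([0, 1] * 4)},
    "regression": {"pred": torch.randn(8),    "target": torch.rand(8) * 90.0},
}
for name, r in cases.items():
    pred, target = r["pred"], r["target"]
    is_cls = pred.ndim == 2 or (pred.ndim == 1 and target.unique().size(0) == 2)
    print(name, is_cls)  # multiclass True, binary True, regression False
```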
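The added `train_metric.reset()` / `valid_metric.reset()` calls matter if these are stateful torchmetrics objects: `compute()` does not clear accumulated state, so without an explicit reset each epoch would report running statistics over every batch seen so far. A minimal sketch, assuming a torchmetrics metric:

```python
import torch
import torchmetrics

metric = torchmetrics.MeanAbsoluteError()
metric.update(torch.tensor([1.0]), torch.tensor([3.0]))  # one epoch-1 batch
print(metric.compute())  # tensor(2.)
metric.reset()           # without this, epoch 2 would still include the batch above
```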
result["target"] + if pred.ndim == 2: + return True + elif (pred.ndim == 1) & (target.unique().size(0) == 2): + # Binary case + return True + else: + return False diff --git a/sweep.py b/sweep.py index f28f5cc..684815d 100644 --- a/sweep.py +++ b/sweep.py @@ -45,8 +45,7 @@ def override_config(hydra_config: omegaconf.DictConfig, if nkeys == 1: # If no . found in key # This implies override from defaults - _subcfg = load_sweep_yaml(config_path=f"{config_path}/{key}", - config_name=f"{value}.yaml") + _subcfg = load_yaml(config_path=f"{config_path}/{key}", config_name=f"{value}.yaml") hydra_config[key] = _subcfg else: _c = hydra_config[key_list[0]] @@ -56,8 +55,10 @@ def override_config(hydra_config: omegaconf.DictConfig, else: _c[_k] = value if "sweep" in hydra_config.hydra: + # Configure directory for sweep. sweep_main_dir/subdir hydra_config.hydra.sweep.subdir = "_".join([f"{k}={v}" for k, v in update_dict.items()]) - hydra_config.callbacks.checkpoint.dirpath = hydra_config.hydra.sweep.subdir + dirpath = f"{hydra_config.hydra.sweep.dir}/{hydra_config.hydra.sweep.subdir}" + hydra_config.callbacks.checkpoint.dirpath = dirpath return hydra_config @@ -65,16 +66,12 @@ def load_default_hydra_config(config_path: str = "config", config_name: str = "train.yaml", version_base="1.1", overrides: List[str] = []) -> omegaconf.DictConfig: - """ In order to apply wandb.sweep into lightning, - we need to remove wandb.callback arguments that was used to log the original experiment. - """ with hydra.initialize(config_path=config_path, version_base=version_base): - config = hydra.compose(config_name=config_name, - overrides=overrides, return_hydra_config=True) + config = hydra.compose(config_name=config_name, overrides=overrides, return_hydra_config=True) return config -def load_sweep_yaml(config_path: str = "config/sweep", config_name: str = "sweep.yaml") -> dict: +def load_yaml(config_path: str = "config/sweep", config_name: str = "sweep.yaml") -> dict: with open(os.path.join(config_path, config_name), mode="r") as f: sweep_cfg = yaml.load(stream=f, Loader=yaml.FullLoader) return sweep_cfg @@ -82,7 +79,7 @@ def load_sweep_yaml(config_path: str = "config/sweep", config_name: str = "sweep def main(config: omegaconf.DictConfig, config_path: str = "config") -> float: wandb.init(project="brain-age") - print(wandb.config) + logger.info("Sweep Config: %s", wandb.config) updated_config = override_config(hydra_config=config, update_dict=wandb.config, config_path=config_path) @@ -103,7 +100,7 @@ def main(config: omegaconf.DictConfig, config_path: str = "config") -> float: func: Callable = partial(main, config=config, config_path=args.config_path) # Load wandb.sweep configuration and instantiation - sweep_cfg = load_sweep_yaml(config_path=os.path.join(args.config_path, "sweep"), - config_name=args.sweep_cfg_name) + sweep_cfg = load_yaml(config_path=os.path.join(args.config_path, "sweep"), + config_name=args.sweep_cfg_name) sweep_id = wandb.sweep(sweep=sweep_cfg, project=args.wandb_project, entity=args.wandb_entity) wandb.agent(sweep_id=sweep_id, function=func)