Fix sweeper subdir. Fix metrics calculation after epochs
1pha committed Mar 11, 2024
1 parent ae254fb commit e08ef5e
Showing 3 changed files with 38 additions and 22 deletions.
28 changes: 18 additions & 10 deletions sage/trainer/trainer.py
@@ -78,18 +78,18 @@ def __init__(self,
         self.init_transforms(augmentation=augmentation)
         self.save_dir = Path(save_dir)
         self.task = task
 
     def setup(self, stage):
         self.log_brain(return_path=False)
 
     def init_transforms(self, augmentation: omegaconf.DictConfig):
         self.train_transforms = mt.Compose([
             mt.Lambda(func=utils.brain2augment),
             mt.Resize(spatial_size=augmentation.get("spatial_size", C.SPATIAL_SIZE)),
             mt.ScaleIntensity(channel_wise=True),
             mt.RandAdjustContrast(prob=0.1, gamma=(0.5, 2.0)),
             mt.RandCoarseDropout(holes=20, spatial_size=8, prob=0.4, fill_value=0.),
-            # mt.RandAxisFlip(prob=0.5),
+            mt.RandAxisFlip(prob=0.5),
             mt.RandZoom(prob=0.4, min_zoom=0.9, max_zoom=1.4, mode="trilinear"),
             mt.Lambda(func=utils.augment2brain),
         ])
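
The change in this hunk re-enables the previously commented-out random axis flip augmentation. A minimal sketch of what MONAI's RandAxisFlip does to a channel-first volume, under the assumption that inputs are (C, H, W, D) tensors as in the pipeline above (the dummy tensor and prob=1.0 are illustrative only):

    import torch
    import monai.transforms as mt

    flip = mt.RandAxisFlip(prob=1.0)  # prob=1.0 forces a flip; the trainer uses prob=0.5
    vol = torch.arange(8, dtype=torch.float32).reshape(1, 2, 2, 2)  # (C, H, W, D)
    out = flip(vol)  # mirrored along one randomly chosen spatial axis
    print(out.shape)  # torch.Size([1, 2, 2, 2]) -- shape is unchanged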
@@ -227,12 +227,20 @@ def move_device(self,
             if key not in exclude_keys:
                 result[key] = result[key].to("cpu")
         return result
 
     def log_confusion_matrix(self, result: dict):
-        probs = result["pred"].cpu().detach()
+        probs = result["pred"]
+        if probs.ndim == 1:
+            # Binary classification
+            probs = torch.nn.functional.sigmoid(probs)
+            probs = torch.stack([1-probs, probs]).T.cpu().numpy()
+        else:
+            probs = probs.cpu().detach()
         labels = result["target"].cpu().numpy()
         cf = wandb.plot.confusion_matrix(probs=probs, y_true=labels)
-        self.logger.experiment.log({"confusion_matrix": cf})
+        roc = wandb.plot.roc_curve(y_true=labels, y_probas=probs)
+        pr = wandb.plot.pr_curve(y_true=labels, y_probas=probs)
+        self.logger.experiment.log({"confusion_matrix": cf, "roc_curve": roc, "pr_curve": pr})
 
     def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False):
         output = {f"{unit}/{k}": float(v) for k, v in output.items()}
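
The binary branch above turns a 1-D logit vector into a (B, 2) probability matrix, because wandb.plot.confusion_matrix, roc_curve, and pr_curve all expect one probability column per class. The same tensor manipulation in isolation (logit values invented for the example):

    import torch

    logits = torch.tensor([2.0, -1.0, 0.5])  # (B,) raw scores for the positive class
    p = torch.sigmoid(logits)                # P(class 1) per sample
    probs = torch.stack([1 - p, p]).T        # (2, B) -> (B, 2) columns [P(0), P(1)]
    print(probs.sum(dim=1))                  # tensor([1., 1., 1.]) -- valid per-row distributions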
Expand Down Expand Up @@ -262,6 +270,7 @@ def on_train_epoch_end(self):
output: dict = self.train_metric.compute()
self.log_result(output, unit="epoch")
self.training_step_outputs.clear()
self.train_metric.reset()

def validation_step(self, batch, batch_idx):
result: dict = self.forward(batch, mode="valid")
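
The added reset() call is the "metrics calculation after epochs" fix from the commit title: torchmetrics-style objects accumulate internal state across update() calls, so without a reset at epoch end, compute() keeps returning a running aggregate over all epochs seen so far. A small sketch of the failure mode, assuming train_metric is a torchmetrics metric (which the compute()/reset() API suggests):

    import torch
    from torchmetrics.classification import BinaryAccuracy

    metric = BinaryAccuracy()
    metric.update(torch.tensor([1, 1]), torch.tensor([1, 1]))  # "epoch 1": 2/2 correct
    print(metric.compute())  # tensor(1.)
    metric.update(torch.tensor([0, 0]), torch.tensor([1, 1]))  # "epoch 2": 0/2 correct
    print(metric.compute())  # tensor(0.5000) -- still averaging epoch 1 in
    metric.reset()           # what the fix adds, so each epoch starts clean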
@@ -275,10 +284,10 @@ def on_validation_epoch_end(self):
         self.log_result(output, unit="epoch", prog_bar=True)
 
         result = utils._sort_outputs(outputs=self.validation_step_outputs)
-        if result["pred"].ndim == 2:
-            """ Assuming prediction with (B, C) shape is a classification task"""
+        if utils.check_classification(result=result):
             self.log_confusion_matrix(result=result)
         self.validation_step_outputs.clear()
+        self.valid_metric.reset()
 
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
         result: dict = self.forward(batch, mode="test")
@@ -287,8 +296,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
 
     def on_predict_end(self):
         result = utils._sort_outputs(outputs=self.prediction_step_outputs)
-        if result["pred"].ndim == 2:
-            """ Assuming prediction with (B, C) shape is a classification task"""
+        if utils.check_classification(result=result):
             self.log_confusion_matrix(result=result)
 
     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
11 changes: 11 additions & 0 deletions sage/trainer/utils.py
@@ -265,3 +265,14 @@ def augment2brain(brain: torch.Tensor) -> torch.Tensor:
     C = brain.shape[1]
     assert C == 1, f"Brain should have single-channel: #channels = {C}"
     return brain
+
+
+def check_classification(result: dict) -> bool:
+    pred, target = result["pred"], result["target"]
+    if pred.ndim == 2:
+        return True
+    elif (pred.ndim == 1) & (target.unique().size(0) == 2):
+        # Binary case
+        return True
+    else:
+        return False
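
The three branches in one illustration: 2-D predictions are treated as multi-class, 1-D predictions with exactly two target values as binary, and anything else (e.g. age regression) as non-classification. The tensors below are invented for the example:

    import torch

    multi   = {"pred": torch.randn(4, 3), "target": torch.tensor([0, 1, 2, 1])}  # (B, C) logits
    binary  = {"pred": torch.randn(4),    "target": torch.tensor([0, 1, 1, 0])}  # (B,) logits
    regress = {"pred": torch.randn(4),    "target": torch.randn(4)}              # continuous targets

    print(check_classification(result=multi))    # True
    print(check_classification(result=binary))   # True
    print(check_classification(result=regress))  # False (targets have ~4 unique values)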
21 changes: 9 additions & 12 deletions sweep.py
@@ -45,8 +45,7 @@ def override_config(hydra_config: omegaconf.DictConfig,
         if nkeys == 1:
             # If no . found in key
             # This implies override from defaults
-            _subcfg = load_sweep_yaml(config_path=f"{config_path}/{key}",
-                                      config_name=f"{value}.yaml")
+            _subcfg = load_yaml(config_path=f"{config_path}/{key}", config_name=f"{value}.yaml")
             hydra_config[key] = _subcfg
         else:
             _c = hydra_config[key_list[0]]
@@ -56,33 +55,31 @@ def override_config(hydra_config: omegaconf.DictConfig,
             else:
                 _c[_k] = value
     if "sweep" in hydra_config.hydra:
+        # Configure directory for sweep. sweep_main_dir/subdir
         hydra_config.hydra.sweep.subdir = "_".join([f"{k}={v}" for k, v in update_dict.items()])
-        hydra_config.callbacks.checkpoint.dirpath = hydra_config.hydra.sweep.subdir
+        dirpath = f"{hydra_config.hydra.sweep.dir}/{hydra_config.hydra.sweep.subdir}"
+        hydra_config.callbacks.checkpoint.dirpath = dirpath
     return hydra_config
 
 
 def load_default_hydra_config(config_path: str = "config",
                               config_name: str = "train.yaml",
                               version_base="1.1",
                               overrides: List[str] = []) -> omegaconf.DictConfig:
-    """ In order to apply wandb.sweep into lightning,
-    we need to remove wandb.callback arguments that was used to log the original experiment.
-    """
     with hydra.initialize(config_path=config_path, version_base=version_base):
-        config = hydra.compose(config_name=config_name,
-                               overrides=overrides, return_hydra_config=True)
+        config = hydra.compose(config_name=config_name, overrides=overrides, return_hydra_config=True)
     return config
 
 
-def load_sweep_yaml(config_path: str = "config/sweep", config_name: str = "sweep.yaml") -> dict:
+def load_yaml(config_path: str = "config/sweep", config_name: str = "sweep.yaml") -> dict:
     with open(os.path.join(config_path, config_name), mode="r") as f:
         sweep_cfg = yaml.load(stream=f, Loader=yaml.FullLoader)
     return sweep_cfg
 
 
 def main(config: omegaconf.DictConfig, config_path: str = "config") -> float:
     wandb.init(project="brain-age")
-    print(wandb.config)
+    logger.info("Sweep Config: %s", wandb.config)
     updated_config = override_config(hydra_config=config,
                                      update_dict=wandb.config,
                                      config_path=config_path)
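
The subdir half of the fix: dirpath was previously set to the bare subdir string, so checkpoints were written relative to the working directory rather than under the sweep's main directory. With the change, the two hydra fields are joined explicitly. A sketch with made-up values for hydra.sweep.dir and the swept parameters:

    sweep_dir = "outputs/sweep/2024-03-11"       # hypothetical hydra.sweep.dir
    update_dict = {"lr": 1e-4, "batch_size": 8}  # hypothetical wandb.config
    subdir = "_".join([f"{k}={v}" for k, v in update_dict.items()])
    dirpath = f"{sweep_dir}/{subdir}"
    print(dirpath)  # outputs/sweep/2024-03-11/lr=0.0001_batch_size=8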
@@ -103,7 +100,7 @@ def main(config: omegaconf.DictConfig, config_path: str = "config") -> float:
     func: Callable = partial(main, config=config, config_path=args.config_path)
 
     # Load wandb.sweep configuration and instantiation
-    sweep_cfg = load_sweep_yaml(config_path=os.path.join(args.config_path, "sweep"),
-                                config_name=args.sweep_cfg_name)
+    sweep_cfg = load_yaml(config_path=os.path.join(args.config_path, "sweep"),
+                          config_name=args.sweep_cfg_name)
     sweep_id = wandb.sweep(sweep=sweep_cfg, project=args.wandb_project, entity=args.wandb_entity)
     wandb.agent(sweep_id=sweep_id, function=func)
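
Since load_yaml just returns a plain dict, the file passed to wandb.sweep follows the standard wandb sweep schema. A minimal, hypothetical example in dict form (the parameter names are placeholders, not the repo's actual config keys); note that dotted keys are split by override_config, while bare keys swap in a whole defaults file:

    sweep_cfg = {
        "method": "bayes",
        "metric": {"name": "epoch/valid_loss", "goal": "minimize"},
        "parameters": {
            "optim.lr": {"min": 1e-5, "max": 1e-3},       # dotted: nested override
            "model": {"values": ["resnet", "convnext"]},  # bare: loads config/model/<value>.yaml
        },
    }
    # wandb.agent then invokes `func` once per trial, and the chosen values
    # arrive inside main() as wandb.config.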
