diff --git a/inference.py b/inference.py
old mode 100644
new mode 100755
diff --git a/sage/trainer/trainer.py b/sage/trainer/trainer.py
index d7d4ff8..c2e3067 100644
--- a/sage/trainer/trainer.py
+++ b/sage/trainer/trainer.py
@@ -12,6 +12,7 @@ from torchmetrics import MetricCollection
 import wandb
 
 import monai.transforms as mt
+from monai.data.meta_tensor import MetaTensor
 
 import sage
 import sage.xai.nilearn_plots as nilp_
@@ -179,6 +180,15 @@ def configure_scheduler(self,
     def configure_optimizers(self) -> torch.optim.Optimizer | dict:
         return self.opt_config
 
+    def get_tensor(self, tensor: torch.Tensor | MetaTensor) -> torch.Tensor:
+        """ monai.transforms returns an internal tensor class called "MetaTensor".
+        This datatype can break some functions (e.g. torch.compile).
+        Use this method to resolve the issue. """
+        if isinstance(tensor, MetaTensor):
+            # MetaTensor is not suitable for torch.compile
+            tensor = tensor.as_tensor()
+        return tensor
+
     def forward(self, batch, mode: str = "train"):
         try:
             """ model should return dict of
@@ -199,6 +209,12 @@ def forward(self, batch, mode: str = "train"):
             logger.exception(e)
             breakpoint()
             raise e
+
+    def log_confusion_matrix(self, result: dict):
+        probs = result["pred"].detach().cpu().numpy()
+        labels = result["target"].cpu().numpy()
+        cf = wandb.plot.confusion_matrix(probs=probs, y_true=labels)
+        self.logger.experiment.log({"confusion_matrix": cf})
 
     def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False):
         output = {f"{unit}/{k}": float(v) for k, v in output.items()}
@@ -241,8 +257,11 @@ def validation_step(self, batch, batch_idx):
     def on_validation_epoch_end(self):
         output: dict = self.valid_metric.compute()
         self.log_result(output, unit="epoch", prog_bar=True)
+        if output["pred"].ndim == 2:
+            # Assume a prediction of shape (B, C) means a classification task
+            self.log_confusion_matrix(result=output)
         self.validation_step_outputs.clear()
-    
+
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
         result: dict = self.forward(batch, mode="test")
         return result
diff --git a/sage/xai/trainer.py b/sage/xai/trainer.py
index 2eadd10..0a286a1 100644
--- a/sage/xai/trainer.py
+++ b/sage/xai/trainer.py
@@ -124,7 +124,7 @@ def _configure_xai(self,
        attr_mtd = ca.InputXGradient(forward_func=model.backbone)
        xai = ca.NoiseTunnel(attribution_method=attr_mtd)
        if xai_call_kwarg is None:
-           xai_call_kwarg = dict(nt_type="smoothgrad", nt_samples=10)
+           xai_call_kwarg = dict(nt_type="smoothgrad", nt_samples=15)
        else:
            breakpoint()
        self.xai_call_kwarg = dict() if xai_call_kwarg is None else xai_call_kwarg
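
Why `get_tensor` is needed: MONAI transforms return `MetaTensor`, a `torch.Tensor` subclass carrying metadata that some torch utilities (notably `torch.compile`'d code paths) may not handle cleanly. A minimal standalone sketch of the conversion (the toy input shape and variable names here are illustrative, not from the repo):

```python
import torch
from monai.data.meta_tensor import MetaTensor


def get_tensor(tensor: torch.Tensor | MetaTensor) -> torch.Tensor:
    # Strip MONAI metadata so downstream code sees a plain torch.Tensor.
    if isinstance(tensor, MetaTensor):
        tensor = tensor.as_tensor()
    return tensor


meta = MetaTensor(torch.randn(1, 1, 8, 8, 8))  # e.g. the output of a monai.transforms pipeline
plain = get_tensor(meta)
assert isinstance(plain, torch.Tensor) and not isinstance(plain, MetaTensor)
```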
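How `log_confusion_matrix` is expected to be fed, assuming `result["pred"]` holds per-class probabilities of shape (B, C) and `result["target"]` integer class labels of shape (B,); `wandb` derives the predicted class via argmax. A hedged sketch with random stand-in data (project name and array shapes are made up for illustration):

```python
import numpy as np
import wandb

run = wandb.init(project="sage-demo", mode="offline")  # offline mode keeps this runnable without a wandb account

probs = np.random.rand(16, 3)           # stand-in for result["pred"].detach().cpu().numpy(), shape (B, C)
labels = np.random.randint(0, 3, 16)    # stand-in for result["target"].cpu().numpy(), shape (B,)

cf = wandb.plot.confusion_matrix(probs=probs, y_true=labels)
run.log({"confusion_matrix": cf})
run.finish()
```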
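Context for the `nt_samples` change (10 → 15): `NoiseTunnel` with `nt_type="smoothgrad"` averages `InputXGradient` attributions over several noise-perturbed copies of the input, so more samples give smoother attribution maps at higher compute cost. A hedged sketch, assuming `ca` is `captum.attr` and using a toy model in place of `model.backbone`:

```python
import torch
import captum.attr as ca

# Toy stand-in for model.backbone: flatten an 8x8 input and classify into 2 classes.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8 * 8, 2))

attr_mtd = ca.InputXGradient(forward_func=model)
xai = ca.NoiseTunnel(attribution_method=attr_mtd)

x = torch.randn(4, 1, 8, 8, requires_grad=True)
attributions = xai.attribute(x, target=0, nt_type="smoothgrad", nt_samples=15)
print(attributions.shape)  # same shape as the input
```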