Skip to content

Commit

Permalink
Add classification task
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 5, 2024
1 parent 407643b commit 5ceba64
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
Empty file modified inference.py
100644 → 100755
Empty file.
21 changes: 20 additions & 1 deletion sage/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torchmetrics import MetricCollection
import wandb
import monai.transforms as mt
from monai.data.meta_tensor import MetaTensor

import sage
import sage.xai.nilearn_plots as nilp_
Expand Down Expand Up @@ -179,6 +180,15 @@ def configure_scheduler(self,
def configure_optimizers(self) -> torch.optim.Optimizer | dict:
return self.opt_config

def get_tensor(self, tensor: torch.Tensor | MetaTensor) -> torch.Tensor:
""" monai.transforms return a internal tensor class called "Metatensor"
This datatype may throw an error for some functions from time to time. (e.g torch.compile)
Use this method to resolve the issue. """
if isinstance(tensor, MetaTensor):
# MetaTensor is not suitable for torch.compile
tensor = tensor.as_tensor()
return tensor

def forward(self, batch, mode: str = "train"):
try:
""" model should return dict of
Expand All @@ -199,6 +209,12 @@ def forward(self, batch, mode: str = "train"):
logger.exception(e)
breakpoint()
raise e

def log_confusion_matrix(self, result: dict):
probs = result["pred"].cpu().detach()
labels = result["target"].cpu().numpy()
cf = wandb.plot.confusion_matrix(probs=probs, y_true=labels)
self.logger.experiment.log({"confusion_matrix": cf})

def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False):
output = {f"{unit}/{k}": float(v) for k, v in output.items()}
Expand Down Expand Up @@ -241,8 +257,11 @@ def validation_step(self, batch, batch_idx):
def on_validation_epoch_end(self):
output: dict = self.valid_metric.compute()
self.log_result(output, unit="epoch", prog_bar=True)
if output["pred"].ndim == 2:
""" Assuming prediction with (B, C) shape is a classification task"""
self.log_confusion_matrix(result=output)
self.validation_step_outputs.clear()

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
result: dict = self.forward(batch, mode="test")
return result
Expand Down
2 changes: 1 addition & 1 deletion sage/xai/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _configure_xai(self,
attr_mtd = ca.InputXGradient(forward_func=model.backbone)
xai = ca.NoiseTunnel(attribution_method=attr_mtd)
if xai_call_kwarg is None:
xai_call_kwarg = dict(nt_type="smoothgrad", nt_samples=10)
xai_call_kwarg = dict(nt_type="smoothgrad", nt_samples=15)
else:
breakpoint()
self.xai_call_kwarg = dict() if xai_call_kwarg is None else xai_call_kwarg
Expand Down

0 comments on commit 5ceba64

Please sign in to comment.