Skip to content

Commit

Permalink
Add confusion matrix on the end of prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Mar 5, 2024
1 parent 0913fe2 commit 14b3012
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 41 deletions.
3 changes: 2 additions & 1 deletion config/train_cls.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ defaults:
- model: convnext_cls
- dataset: ppmi
- scheduler: cosine_anneal_warmup
- optim: lion
- optim: adamw

dataloader:
_target_: torch.utils.data.DataLoader
Expand All @@ -30,6 +30,7 @@ module:
spatial_size: [ 160 , 192 , 160 ]
load_from_checkpoint: # for lightning
load_model_ckpt: # For model checkpoint only
log_train_metrics: True
separate_lr:
save_dir: ${callbacks.checkpoint.dirpath}

Expand Down
49 changes: 9 additions & 40 deletions sage/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,13 @@


class ModelBase(nn.Module):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str,
task: str = "reg"):
def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__()
logger.info("Start Initiating model %s", name.upper())
# self.backbone = torch.compile(backbone)
self.backbone = backbone
self.criterion = criterion
self.NAME = name
self.TASK = task

def forward(self, brain: torch.Tensor, age: torch.Tensor):
pred = self.backbone(brain).squeeze()
loss = self.criterion(pred, age.long())
return dict(loss=loss,
pred=pred.detach().cpu(),
target=age.detach().cpu().long())

def _forward(self, brain: torch.Tensor):
return self.backbone(brain)
Expand All @@ -45,7 +33,7 @@ def parse_ckpt(s: str):
return s
ckpt = {parse_ckpt(k): v for k, v in ckpt.items()}
self.load_state_dict(ckpt)

def conv_layers(self):
if hasattr(self.backbone, "conv_layers"):
return self.backbone.conv_layers()
Expand All @@ -57,57 +45,38 @@ class ClsBase(ModelBase):
def forward(self, brain: torch.Tensor, age: torch.Tensor):
pred = self.backbone(brain).squeeze()
loss = self.criterion(pred, age.long())
return dict(loss=loss,
pred=pred.detach().cpu(),
target=age.detach().cpu().long())
return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu().long())


class RegBase(ModelBase):
def forward(self, brain: torch.Tensor, age: torch.Tensor):
pred = self.backbone(brain).squeeze()
loss = self.criterion(pred, age.float())
return dict(loss=loss,
pred=pred.detach().cpu(),
target=age.detach().cpu())
return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu())


class ResNet(RegBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str):
def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class ConvNext(RegBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str):
def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class ResNetCls(RegBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str):
def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class ConvNextCls(RegBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str):
def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)


class SFCNModel(ModelBase):
def __init__(self,
backbone: nn.Module,
criterion: nn.Module,
name: str):
def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str):
super().__init__(backbone=backbone, criterion=criterion, name=name)
# TODO: bin_range interval and backbones' output_dim should be matched,
# but they are separately hard-coded!
Expand Down
8 changes: 8 additions & 0 deletions sage/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self,
self.log_lr = manual_lr
self.training_step_outputs = []
self.validation_step_outputs = []
self.prediction_step_outputs = []

self.init_transforms(augmentation=augmentation)
self.save_dir = Path(save_dir)
Expand Down Expand Up @@ -265,8 +266,15 @@ def on_validation_epoch_end(self):

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
result: dict = self.forward(batch, mode="test")
self.prediction_step_outputs.append(result)
return result

def on_predict_end(self):
result = utils._sort_outputs(outputs=self.validation_step_outputs)
if result["pred"].ndim == 2:
""" Assuming prediction with (B, C) shape is a classification task"""
self.log_confusion_matrix(result=result)

def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
result: dict = self.forward(batch, mode="test")
return result
Expand Down

0 comments on commit 14b3012

Please sign in to comment.