diff --git a/config/train_cls.yaml b/config/train_cls.yaml index 764961b..c0d930a 100644 --- a/config/train_cls.yaml +++ b/config/train_cls.yaml @@ -9,7 +9,7 @@ defaults: - model: convnext_cls - dataset: ppmi - scheduler: cosine_anneal_warmup - - optim: lion + - optim: adamw dataloader: _target_: torch.utils.data.DataLoader @@ -30,6 +30,7 @@ module: spatial_size: [ 160 , 192 , 160 ] load_from_checkpoint: # for lightning load_model_ckpt: # For model checkpoint only + log_train_metrics: True separate_lr: save_dir: ${callbacks.checkpoint.dirpath} diff --git a/sage/models/base.py b/sage/models/base.py index 35c88ce..7bd2d9a 100644 --- a/sage/models/base.py +++ b/sage/models/base.py @@ -10,25 +10,13 @@ class ModelBase(nn.Module): - def __init__(self, - backbone: nn.Module, - criterion: nn.Module, - name: str, - task: str = "reg"): + def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__() logger.info("Start Initiating model %s", name.upper()) # self.backbone = torch.compile(backbone) self.backbone = backbone self.criterion = criterion self.NAME = name - self.TASK = task - - def forward(self, brain: torch.Tensor, age: torch.Tensor): - pred = self.backbone(brain).squeeze() - loss = self.criterion(pred, age.long()) - return dict(loss=loss, - pred=pred.detach().cpu(), - target=age.detach().cpu().long()) def _forward(self, brain: torch.Tensor): return self.backbone(brain) @@ -45,7 +33,7 @@ def parse_ckpt(s: str): return s ckpt = {parse_ckpt(k): v for k, v in ckpt.items()} self.load_state_dict(ckpt) - + def conv_layers(self): if hasattr(self.backbone, "conv_layers"): return self.backbone.conv_layers() @@ -57,57 +45,38 @@ class ClsBase(ModelBase): def forward(self, brain: torch.Tensor, age: torch.Tensor): pred = self.backbone(brain).squeeze() loss = self.criterion(pred, age.long()) - return dict(loss=loss, - pred=pred.detach().cpu(), - target=age.detach().cpu().long()) + return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu().long()) class RegBase(ModelBase): def forward(self, brain: torch.Tensor, age: torch.Tensor): pred = self.backbone(brain).squeeze() loss = self.criterion(pred, age.float()) - return dict(loss=loss, - pred=pred.detach().cpu(), - target=age.detach().cpu()) + return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu()) class ResNet(RegBase): - def __init__(self, - backbone: nn.Module, - criterion: nn.Module, - name: str): + def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) class ConvNext(RegBase): - def __init__(self, - backbone: nn.Module, - criterion: nn.Module, - name: str): + def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) class ResNetCls(RegBase): - def __init__(self, - backbone: nn.Module, - criterion: nn.Module, - name: str): + def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) class ConvNextCls(RegBase): - def __init__(self, - backbone: nn.Module, - criterion: nn.Module, - name: str): + def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) class SFCNModel(ModelBase): - def __init__(self, - backbone: nn.Module, - criterion: nn.Module, - name: str): + def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) # TODO: bin_range interval and backbones' output_dim should be matched, # but they are separately hard-coded! diff --git a/sage/trainer/trainer.py b/sage/trainer/trainer.py index f3a7126..968048a 100644 --- a/sage/trainer/trainer.py +++ b/sage/trainer/trainer.py @@ -73,6 +73,7 @@ def __init__(self, self.log_lr = manual_lr self.training_step_outputs = [] self.validation_step_outputs = [] + self.prediction_step_outputs = [] self.init_transforms(augmentation=augmentation) self.save_dir = Path(save_dir) @@ -265,8 +266,15 @@ def on_validation_epoch_end(self): def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): result: dict = self.forward(batch, mode="test") + self.prediction_step_outputs.append(result) return result + def on_predict_end(self): + result = utils._sort_outputs(outputs=self.validation_step_outputs) + if result["pred"].ndim == 2: + """ Assuming prediction with (B, C) shape is a classification task""" + self.log_confusion_matrix(result=result) + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): result: dict = self.forward(batch, mode="test") return result