diff --git a/config/train_cls.yaml b/config/train_cls.yaml index c0d930a..00fae94 100644 --- a/config/train_cls.yaml +++ b/config/train_cls.yaml @@ -8,7 +8,7 @@ defaults: - _self_ - model: convnext_cls - dataset: ppmi - - scheduler: cosine_anneal_warmup + - scheduler: exp_decay - optim: adamw dataloader: diff --git a/sage/models/base.py b/sage/models/base.py index 7bd2d9a..588aaa3 100644 --- a/sage/models/base.py +++ b/sage/models/base.py @@ -41,18 +41,18 @@ def conv_layers(self): return find_conv_modules(self.backbone) -class ClsBase(ModelBase): +class RegBase(ModelBase): def forward(self, brain: torch.Tensor, age: torch.Tensor): pred = self.backbone(brain).squeeze() - loss = self.criterion(pred, age.long()) - return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu().long()) + loss = self.criterion(pred, age.float()) + return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu()) -class RegBase(ModelBase): +class ClsBase(ModelBase): def forward(self, brain: torch.Tensor, age: torch.Tensor): pred = self.backbone(brain).squeeze() - loss = self.criterion(pred, age.float()) - return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu()) + loss = self.criterion(pred, age.long()) + return dict(loss=loss, pred=pred.detach().cpu(), target=age.detach().cpu().long()) class ResNet(RegBase): @@ -65,12 +65,12 @@ def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) -class ResNetCls(RegBase): +class ResNetCls(ClsBase): def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name) -class ConvNextCls(RegBase): +class ConvNextCls(ClsBase): def __init__(self, backbone: nn.Module, criterion: nn.Module, name: str): super().__init__(backbone=backbone, criterion=criterion, name=name)