From 0913fe24a606a0a0ce429040509bbf0ecd3ef7b7 Mon Sep 17 00:00:00 2001 From: 1pha <1phantasmas@korea.ac.kr> Date: Tue, 5 Mar 2024 01:21:31 +0000 Subject: [PATCH] Diverge abstract class --- config/model/convnext_cls.yaml | 2 +- config/model/resnet_cls.yaml | 2 +- sage/models/base.py | 38 ++++++++++++++++++++++++++++++++-- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/config/model/convnext_cls.yaml b/config/model/convnext_cls.yaml index a3fd24e..84ea951 100644 --- a/config/model/convnext_cls.yaml +++ b/config/model/convnext_cls.yaml @@ -1,4 +1,4 @@ -_target_: sage.models.base.ConvNext +_target_: sage.models.base.ConvNextCls backbone: _target_: sage.models.model_zoo.convnext.build_convnext model_name: convnext-base diff --git a/config/model/resnet_cls.yaml b/config/model/resnet_cls.yaml index 4d9885b..6d1d7d8 100644 --- a/config/model/resnet_cls.yaml +++ b/config/model/resnet_cls.yaml @@ -1,4 +1,4 @@ -_target_: sage.models.base.ResNet +_target_: sage.models.base.ResNetCls backbone: _target_: sage.models.model_zoo.resnet.build_resnet model_depth: 10 diff --git a/sage/models/base.py b/sage/models/base.py index 818fc82..35c88ce 100644 --- a/sage/models/base.py +++ b/sage/models/base.py @@ -53,7 +53,41 @@ def conv_layers(self): return find_conv_modules(self.backbone) -class ResNet(ModelBase): +class ClsBase(ModelBase): + def forward(self, brain: torch.Tensor, age: torch.Tensor): + pred = self.backbone(brain).squeeze() + loss = self.criterion(pred, age.long()) + return dict(loss=loss, + pred=pred.detach().cpu(), + target=age.detach().cpu().long()) + + +class RegBase(ModelBase): + def forward(self, brain: torch.Tensor, age: torch.Tensor): + pred = self.backbone(brain).squeeze() + loss = self.criterion(pred, age.float()) + return dict(loss=loss, + pred=pred.detach().cpu(), + target=age.detach().cpu()) + + +class ResNet(RegBase): + def __init__(self, + backbone: nn.Module, + criterion: nn.Module, + name: str): + super().__init__(backbone=backbone, criterion=criterion, name=name) + + +class ConvNext(RegBase): + def __init__(self, + backbone: nn.Module, + criterion: nn.Module, + name: str): + super().__init__(backbone=backbone, criterion=criterion, name=name) + + +class ResNetCls(RegBase): def __init__(self, backbone: nn.Module, criterion: nn.Module, @@ -61,7 +95,7 @@ def __init__(self, super().__init__(backbone=backbone, criterion=criterion, name=name) -class ConvNext(ModelBase): +class ConvNextCls(RegBase): def __init__(self, backbone: nn.Module, criterion: nn.Module,