Skip to content

Commit

Permalink
Merge pull request #43 from szmazurek/port_lightning_corrected
Browse files Browse the repository at this point in the history
Port lightning corrected
  • Loading branch information
szmazurek authored Dec 7, 2024
2 parents 53d3849 + bff4381 commit 90c98cd
Show file tree
Hide file tree
Showing 13 changed files with 583 additions and 17 deletions.
3 changes: 2 additions & 1 deletion .spelling/.spelling/expect.txt
Original file line number Diff line number Diff line change
Expand Up @@ -732,4 +732,5 @@ lrs
autograd
cudagraph
kwonly
torchscript
torchscript
Stdnet
4 changes: 2 additions & 2 deletions GANDLF/compute/loss_and_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_metric_output(
if len(temp) > 1:
return temp
else:
# TODO: this branch is extremely age case and is buggy.
# TODO: this branch is extremely edge case and is buggy.
# Overall the case when metric returns a list but of length 1 is very rare. The only case is when
# the metric returns Nx.. tensor (i.e. without aggregation by elements) and batch_size==N==1. This branch
# would definitely fail for such a metrics like
Expand Down Expand Up @@ -134,7 +134,7 @@ def get_loss_and_metrics(
# Metrics should be a list
for metric in params["metrics"]:
metric_lower = metric.lower()
metric_output[metric] = 0
metric_output[metric] = 0.0
if metric_lower not in global_metrics_dict:
warnings.warn("WARNING: Could not find the requested metric '" + metric)
continue
Expand Down
24 changes: 23 additions & 1 deletion GANDLF/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .regression import CE, CEL, MSE_loss, L1_loss
from .hybrid import DCCE, DCCE_Logits, DC_Focal


# global defines for the losses
global_losses_dict = {
"dc": MCD_loss,
Expand All @@ -38,3 +37,26 @@
"focal": FocalLoss,
"dc_focal": DC_Focal,
}


def get_loss(params: dict) -> object:
    """
    Look up the loss definition requested in the configuration.

    Args:
        params (dict): The parameters' dictionary. ``params["loss_function"]``
            is either a string naming the loss, or a dictionary whose first
            key names it.

    Returns:
        loss (object): The entry of ``global_losses_dict`` registered under
            the lower-cased loss name.
    """
    # TODO This check looks like legacy code, should we have it?
    loss_spec = params["loss_function"]
    # A dictionary-style specification carries the loss name as its first key.
    loss_name = list(loss_spec.keys())[0] if isinstance(loss_spec, dict) else loss_spec
    loss_key = loss_name.lower()

    assert (
        loss_key in global_losses_dict
    ), f"Could not find the requested loss function '{params['loss_function']}'"

    return global_losses_dict[loss_key]
83 changes: 83 additions & 0 deletions GANDLF/losses/loss_calculators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
from GANDLF.losses import get_loss
from abc import ABC, abstractmethod
from typing import List


class AbstractLossCalculator(ABC):
    """
    Base class for callables that turn a (prediction, target) pair into a loss.

    The constructor stores the parameters' dictionary as-is (no copy) and
    resolves the configured loss through ``get_loss``; subclasses implement
    ``__call__`` to decide how that loss is applied.
    """

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The parameters' dictionary, kept by reference.
        """
        super().__init__()
        self.params = params
        self._initialize_loss()

    def _initialize_loss(self):
        """Resolve the configured loss and store it as ``self.loss``."""
        configured_loss = get_loss(self.params)
        self.loss = configured_loss

    @abstractmethod
    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """Compute the loss for ``prediction`` against ``target``."""


class LossCalculatorSDNet(AbstractLossCalculator):
    """
    Loss calculator for the SDNet architecture.

    When the model returns its full list of outputs, the loss is the weighted
    sum ``0.01 * kld + reconstruction + 10 * segmentation + cycle``. When only
    a single output is available, the configured loss is applied directly.
    """

    def __init__(self, params):
        """
        Args:
            params (dict): The parameters' dictionary; its ``loss_function``
                entry selects the segmentation loss.
        """
        super().__init__(params)
        # The auxiliary terms are fixed losses that do not depend on the
        # user-configured segmentation loss, so each one is requested by name.
        # Resolving ``params`` three times would yield the segmentation loss
        # three times, which does not match the call signatures used below.
        # NOTE(review): assumes "l1", "kld" and "mse" are registered in
        # global_losses_dict - confirm against GANDLF.losses.
        self.l1_loss = get_loss({"loss_function": "l1"})
        self.kld_loss = get_loss({"loss_function": "kld"})
        self.mse_loss = get_loss({"loss_function": "mse"})

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args):
        """
        Compute the SDNet loss.

        Args:
            prediction: Either the sequence of model outputs (indices 0-4 are
                read) or a single prediction tensor.
            target (torch.Tensor): The ground truth.
            *args: ``args[0]`` must be the input image when ``prediction`` is
                the full output sequence; its first channel is the
                reconstruction target.

        Returns:
            torch.Tensor: The combined loss, or the plain configured loss when
                only a single output is given.
        """
        # The composite loss reads prediction[0..4], so it can only apply when
        # several outputs are present. A plain tensor is a single output even
        # when its leading (batch) dimension is larger than one.
        has_all_outputs = (
            not isinstance(prediction, torch.Tensor) and len(prediction) > 1
        )
        if not has_all_outputs:
            return self.loss(prediction, target, self.params)

        image: torch.Tensor = args[0]
        loss_seg = self.loss(prediction[0], target.squeeze(-1), self.params)
        loss_reco = self.l1_loss(prediction[1], image[:, :1, ...], None)
        loss_kld = self.kld_loss(prediction[2], prediction[3])
        loss_cycle = self.mse_loss(prediction[2], prediction[4], None)
        return 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle


class LossCalculatorDeepSupervision(AbstractLossCalculator):
    """
    Loss calculator for deep-supervision architectures.

    With several predictions, the configured loss is evaluated for each
    ``(prediction[i], target[i])`` pair and the results are combined with the
    fixed weights in ``self.loss_weights``.
    """

    def __init__(self, params):
        """
        Args:
            params (dict): The parameters' dictionary.
        """
        super().__init__(params)
        # This was taken from current Gandlf code, but I am not sure if
        # we should have this set rigidly here: ``__call__`` indexes this list
        # once per prediction, so at most 4 predictions are supported.
        self.loss_weights = [0.5, 0.25, 0.175, 0.075]

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """
        Compute the (possibly weighted, multi-output) loss.

        Args:
            prediction: A sequence of predictions, or a single prediction.
            target: The matching sequence of targets, or a single target.
            *args: Unused.

        Returns:
            torch.Tensor: The weighted sum of the per-output losses when more
                than one prediction is given, otherwise the plain loss.

        Raises:
            IndexError: If there are more predictions than loss weights.
        """
        # NOTE(review): a batched tensor with a leading dimension > 1 also
        # takes the multi-output branch - confirm callers only pass sequences
        # of per-level outputs here.
        if len(prediction) > 1:
            # Build the total out-of-place. Accumulating with ``+=`` into a
            # leaf tensor created with ``requires_grad=True`` raises a
            # RuntimeError in autograd and would also pin the result to CPU.
            return sum(
                self.loss(level_prediction, target[i], self.params)
                * self.loss_weights[i]
                for i, level_prediction in enumerate(prediction)
            )

        return self.loss(prediction, target, self.params)


class LossCalculatorSimple(AbstractLossCalculator):
    """Loss calculator that applies the configured loss to a single output."""

    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """
        Args:
            prediction (torch.Tensor): The model output.
            target (torch.Tensor): The ground truth.
            *args: Unused.

        Returns:
            torch.Tensor: The value returned by the configured loss.
        """
        loss_value = self.loss(prediction, target, self.params)
        return loss_value


class LossCalculatorFactory:
    """Select the loss calculator that matches the configured architecture."""

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The parameters' dictionary; only
                ``params["model"]["architecture"]`` is read by the factory.
        """
        self.params = params

    def get_loss_calculator(self) -> AbstractLossCalculator:
        """
        Build the loss calculator for the configured architecture.

        Returns:
            AbstractLossCalculator: ``LossCalculatorSDNet`` for "sdnet",
                ``LossCalculatorDeepSupervision`` for any architecture whose
                name contains "deep", ``LossCalculatorSimple`` otherwise.
        """
        # Normalise once so both checks are case-insensitive. Previously only
        # the "deep" check lower-cased the name, so a mixed-case spelling of
        # "sdnet" silently fell through to the simple calculator.
        architecture = self.params["model"]["architecture"].lower()
        if architecture == "sdnet":
            return LossCalculatorSDNet(self.params)
        if "deep" in architecture:
            return LossCalculatorDeepSupervision(self.params)
        return LossCalculatorSimple(self.params)
25 changes: 25 additions & 0 deletions GANDLF/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
All the metrics are to be called from here
"""
from warnings import warn
from typing import Union

from GANDLF.losses.regression import MSE_loss, CEL
Expand Down Expand Up @@ -102,6 +103,30 @@
]


def get_metrics(params: dict) -> dict:
    """
    Collect the calculators for the metrics requested in the configuration.

    Metric names are lower-cased before lookup; names missing from
    ``global_metrics_dict`` trigger a ``UserWarning`` and are left out.

    Args:
        params (dict): A dictionary containing the overall training parameters.

    Returns:
        metric_calculators (dict): Maps each lower-cased, known metric name to
            its entry in ``global_metrics_dict``.
    """
    selected = {}
    for requested_name in params["metrics"]:
        metric_key = requested_name.lower()
        if metric_key in global_metrics_dict:
            selected[metric_key] = global_metrics_dict[metric_key]
            continue
        # Unknown metrics are reported and skipped rather than aborting.
        warn(
            f"Metric {metric_key} not found in global metrics dictionary, it will not be used.",
            UserWarning,
        )
    return selected


def overall_stats(predictions, ground_truth, params) -> dict[str, Union[float, list]]:
"""
Generates a dictionary of metrics calculated on the overall predictions and ground truths.
Expand Down
87 changes: 87 additions & 0 deletions GANDLF/metrics/metric_calculators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
from copy import deepcopy
from GANDLF.metrics import get_metrics
from abc import ABC, abstractmethod


class AbstractMetricCalculator(ABC):
    """
    Base class for callables that evaluate the configured metrics.

    The constructor keeps a deep copy of the parameters' dictionary and
    resolves the metric calculators through ``get_metrics``; subclasses
    implement ``__call__``.
    """

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The parameters' dictionary (deep-copied).
        """
        super().__init__()
        self.params = deepcopy(params)
        self._initialize_metrics_dict()

    def _initialize_metrics_dict(self):
        """Resolve the configured metrics into ``self.metrics_calculators``."""
        resolved = get_metrics(self.params)
        self.metrics_calculators = resolved

    def _process_metric_value(self, metric_value: torch.Tensor):
        """
        Convert a metric tensor into plain Python values.

        Args:
            metric_value (torch.Tensor): The metric result.

        Returns:
            A Python scalar for a zero-dimensional tensor, otherwise a
            (possibly nested) list.
        """
        is_scalar = metric_value.dim() == 0
        return metric_value.item() if is_scalar else metric_value.tolist()

    @abstractmethod
    def __call__(
        self, prediction: torch.Tensor, target: torch.Tensor, *args
    ) -> torch.Tensor:
        """Evaluate the metrics for ``prediction`` against ``target``."""


class MetricCalculatorSDNet(AbstractMetricCalculator):
    """Metric calculator used for the SDNet architecture."""

    def __init__(self, params):
        """
        Args:
            params (dict): The parameters' dictionary.
        """
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args):
        """
        Evaluate every configured metric on the given prediction.

        Args:
            prediction (torch.Tensor): The model output.
            target (torch.Tensor): The ground truth.
            *args: Unused.

        Returns:
            dict: Maps each metric name to its value as a Python scalar or list.
        """
        # Each metric result is detached and moved to CPU before conversion.
        return {
            name: self._process_metric_value(
                metric_fn(prediction, target, self.params).detach().cpu()
            )
            for name, metric_fn in self.metrics_calculators.items()
        }


class MetricCalculatorDeepSupervision(AbstractMetricCalculator):
    """Metric calculator used for deep-supervision architectures."""

    def __init__(self, params):
        """
        Args:
            params (dict): The parameters' dictionary.
        """
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args):
        """
        Evaluate every configured metric and sum it over all predictions.

        Args:
            prediction: The sequence of predictions.
            target: The sequence of targets, indexed in step with ``prediction``.
            *args: Unused.

        Returns:
            dict: Maps each metric name to the sum of its values over all
                ``(prediction[i], target[i])`` pairs.
        """
        totals = {}
        for name, metric_fn in self.metrics_calculators.items():
            running_total = 0.0
            for level, level_prediction in enumerate(prediction):
                raw_value = metric_fn(level_prediction, target[level], self.params)
                running_total += self._process_metric_value(
                    raw_value.detach().cpu()
                )
            totals[name] = running_total
        return totals


class MetricCalculatorSimple(AbstractMetricCalculator):
    """Metric calculator for architectures with a single output."""

    def __init__(self, params):
        """
        Args:
            params (dict): The parameters' dictionary.
        """
        super().__init__(params)

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args):
        """
        Evaluate every configured metric on the given prediction.

        Args:
            prediction (torch.Tensor): The model output.
            target (torch.Tensor): The ground truth.
            *args: Unused.

        Returns:
            dict: Maps each metric name to its value as a Python scalar or list.
        """
        results = {}
        for name, metric_fn in self.metrics_calculators.items():
            raw_value = metric_fn(prediction, target, self.params)
            # Detach and move to CPU before converting to Python values.
            results[name] = self._process_metric_value(raw_value.detach().cpu())
        return results


class MetricCalculatorFactory:
    """Select the metric calculator that matches the configured architecture."""

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The parameters' dictionary; only
                ``params["model"]["architecture"]`` is read by the factory.
        """
        self.params = params

    def get_metric_calculator(self) -> AbstractMetricCalculator:
        """
        Build the metric calculator for the configured architecture.

        Returns:
            AbstractMetricCalculator: ``MetricCalculatorSDNet`` for "sdnet",
                ``MetricCalculatorDeepSupervision`` for any architecture whose
                name contains "deep", ``MetricCalculatorSimple`` otherwise.
        """
        # Normalise once so both checks are case-insensitive. Previously only
        # the "deep" check lower-cased the name, so a mixed-case spelling of
        # "sdnet" did not select the SDNet calculator.
        architecture = self.params["model"]["architecture"].lower()
        if architecture == "sdnet":
            return MetricCalculatorSDNet(self.params)
        if "deep" in architecture:
            return MetricCalculatorDeepSupervision(self.params)
        return MetricCalculatorSimple(self.params)
11 changes: 8 additions & 3 deletions GANDLF/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .brain_age import brainage
from .unetr import unetr
from .transunet import transunet
from .modelBase import ModelBase

# Define a dictionary of model architectures and corresponding functions
global_models_dict = {
Expand Down Expand Up @@ -110,14 +111,18 @@
}


def get_model(params):
def get_model(params: dict) -> ModelBase:
"""
Function to get the model definition.
Args:
params (dict): The parameters' dictionary.
Returns:
model (torch.nn.Module): The model definition.
model (ModelBase): The model definition.
"""
return global_models_dict[params["model"]["architecture"]](parameters=params)
chosen_model = params["model"]["architecture"].lower()
assert (
chosen_model in global_models_dict
), f"Could not find the requested model '{params['model']['architecture']}'"
return global_models_dict[chosen_model](parameters=params)
45 changes: 45 additions & 0 deletions GANDLF/models/lightning_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import lightning.pytorch as pl
from GANDLF.models import get_model
from GANDLF.optimizers import get_optimizer
from GANDLF.schedulers import get_scheduler
from GANDLF.losses.loss_calculators import LossCalculatorFactory
from GANDLF.metrics.metric_calculators import MetricCalculatorFactory
from GANDLF.utils.pred_target_processors import PredictionTargetProcessorFactory

from copy import deepcopy


class GandlfLightningModule(pl.LightningModule):
    """
    Lightning module that wires together the model, the loss calculator, the
    metric calculator and the prediction/target processor, all built from one
    parameters' dictionary.
    """

    def __init__(self, params: dict):
        """
        Args:
            params (dict): The parameters' dictionary (deep-copied).
        """
        super().__init__()
        self.params = deepcopy(params)
        # Components are built in this fixed order from the copied parameters.
        self._initialize_model()
        self._initialize_loss()
        self._initialize_metric_calculators()
        self._initialize_preds_target_processor()

    def _initialize_model(self):
        """Build the model and store it as ``self.model``."""
        self.model = get_model(self.params)

    def _initialize_loss(self):
        """Build the loss calculator and store it as ``self.loss``."""
        loss_factory = LossCalculatorFactory(self.params)
        self.loss = loss_factory.get_loss_calculator()

    def _initialize_metric_calculators(self):
        """Build the metric calculator and store it as ``self.metric_calculators``."""
        metric_factory = MetricCalculatorFactory(self.params)
        self.metric_calculators = metric_factory.get_metric_calculator()

    def _initialize_preds_target_processor(self):
        """Build the processor and store it as ``self.pred_target_processor``."""
        processor_factory = PredictionTargetProcessorFactory(self.params)
        self.pred_target_processor = (
            processor_factory.get_prediction_target_processor()
        )

    def configure_optimizers(self):
        """
        Build the optimizer and, if configured, the scheduler.

        Returns:
            The optimizer alone when the parameters have no "scheduler" key,
            otherwise the pair ``([optimizer], [scheduler])``.
        """
        # Work on a copy so the extra entries do not end up in self.params.
        optim_config = deepcopy(self.params)
        optim_config["model_parameters"] = self.model.parameters()
        optimizer = get_optimizer(optim_config)
        if "scheduler" not in self.params:
            return optimizer

        optim_config["optimizer_object"] = optimizer
        return [optimizer], [get_scheduler(optim_config)]
12 changes: 4 additions & 8 deletions GANDLF/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,9 @@ def get_optimizer(params):
optimizer (torch.optim.Optimizer): An instance of the specified optimizer.
"""
# Retrieve the optimizer type from the input parameters
optimizer_type = params["optimizer"]["type"]

chosen_optimizer = params["optimizer"]["type"]
assert (
optimizer_type in global_optimizer_dict
), f"Optimizer type {optimizer_type} not found"

# Create the optimizer instance using the specified type and input parameters
optimizer_function = global_optimizer_dict[optimizer_type]
return optimizer_function(params)
chosen_optimizer in global_optimizer_dict
), f"Could not find the requested optimizer '{params['optimizer']['type']}'"
return global_optimizer_dict[chosen_optimizer](params)
8 changes: 6 additions & 2 deletions GANDLF/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def get_scheduler(params):
params (dict): The parameters' dictionary.
Returns:
model (object): The scheduler definition.
scheduler (object): The scheduler definition.
"""
return global_schedulers_dict[params["scheduler"]["type"]](params)
chosen_scheduler = params["scheduler"]["type"].lower()
assert (
chosen_scheduler in global_schedulers_dict
), f"Could not find the requested scheduler '{params['scheduler']['type']}'"
return global_schedulers_dict[chosen_scheduler](params)
Loading

0 comments on commit 90c98cd

Please sign in to comment.