forked from mlcommons/GaNDLF
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #43 from szmazurek/port_lightning_corrected
Port lightning corrected
- Loading branch information
Showing
13 changed files
with
583 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -732,4 +732,5 @@ lrs | |
autograd | ||
cudagraph | ||
kwonly | ||
torchscript | ||
torchscript | ||
Stdnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import torch | ||
from GANDLF.losses import get_loss | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
|
||
class AbstractLossCalculator(ABC): | ||
def __init__(self, params: dict): | ||
super().__init__() | ||
self.params = params | ||
self._initialize_loss() | ||
|
||
def _initialize_loss(self): | ||
self.loss = get_loss(self.params) | ||
|
||
@abstractmethod | ||
def __call__( | ||
self, prediction: torch.Tensor, target: torch.Tensor, *args | ||
) -> torch.Tensor: | ||
pass | ||
|
||
|
||
class LossCalculatorSDNet(AbstractLossCalculator): | ||
def __init__(self, params): | ||
super().__init__(params) | ||
self.l1_loss = get_loss(params) | ||
self.kld_loss = get_loss(params) | ||
self.mse_loss = get_loss(params) | ||
|
||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args): | ||
if len(prediction) < 2: | ||
image: torch.Tensor = args[0] | ||
loss_seg = self.loss(prediction[0], target.squeeze(-1), self.params) | ||
loss_reco = self.l1_loss(prediction[1], image[:, :1, ...], None) | ||
loss_kld = self.kld_loss(prediction[2], prediction[3]) | ||
loss_cycle = self.mse_loss(prediction[2], prediction[4], None) | ||
return 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle | ||
else: | ||
return self.loss(prediction, target, self.params) | ||
|
||
|
||
class LossCalculatorDeepSupervision(AbstractLossCalculator): | ||
def __init__(self, params): | ||
super().__init__(params) | ||
# This was taken from current Gandlf code, but I am not sure if | ||
# we should have this set rigidly here, as it enforces the number of | ||
# classes to be 4. | ||
self.loss_weights = [0.5, 0.25, 0.175, 0.075] | ||
|
||
def __call__( | ||
self, prediction: torch.Tensor, target: torch.Tensor, *args | ||
) -> torch.Tensor: | ||
if len(prediction) > 1: | ||
loss = torch.tensor(0.0, requires_grad=True) | ||
for i in range(len(prediction)): | ||
loss += ( | ||
self.loss(prediction[i], target[i], self.params) | ||
* self.loss_weights[i] | ||
) | ||
else: | ||
loss = self.loss(prediction, target, self.params) | ||
|
||
return loss | ||
|
||
|
||
class LossCalculatorSimple(AbstractLossCalculator): | ||
def __call__( | ||
self, prediction: torch.Tensor, target: torch.Tensor, *args | ||
) -> torch.Tensor: | ||
return self.loss(prediction, target, self.params) | ||
|
||
|
||
class LossCalculatorFactory: | ||
def __init__(self, params: dict): | ||
self.params = params | ||
|
||
def get_loss_calculator(self) -> AbstractLossCalculator: | ||
if self.params["model"]["architecture"] == "sdnet": | ||
return LossCalculatorSDNet(self.params) | ||
elif "deep" in self.params["model"]["architecture"].lower(): | ||
return LossCalculatorDeepSupervision(self.params) | ||
else: | ||
return LossCalculatorSimple(self.params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import torch | ||
from copy import deepcopy | ||
from GANDLF.metrics import get_metrics | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class AbstractMetricCalculator(ABC): | ||
def __init__(self, params: dict): | ||
super().__init__() | ||
self.params = deepcopy(params) | ||
self._initialize_metrics_dict() | ||
|
||
def _initialize_metrics_dict(self): | ||
self.metrics_calculators = get_metrics(self.params) | ||
|
||
def _process_metric_value(self, metric_value: torch.Tensor): | ||
if metric_value.dim() == 0: | ||
return metric_value.item() | ||
else: | ||
return metric_value.tolist() | ||
|
||
@abstractmethod | ||
def __call__( | ||
self, prediction: torch.Tensor, target: torch.Tensor, *args | ||
) -> torch.Tensor: | ||
pass | ||
|
||
|
||
class MetricCalculatorSDNet(AbstractMetricCalculator): | ||
def __init__(self, params): | ||
super().__init__(params) | ||
|
||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args): | ||
metric_results = {} | ||
for metric_name, metric_calculator in self.metrics_calculators.items(): | ||
metric_value = ( | ||
metric_calculator(prediction, target, self.params).detach().cpu() | ||
) | ||
metric_results[metric_name] = self._process_metric_value(metric_value) | ||
return metric_results | ||
|
||
|
||
class MetricCalculatorDeepSupervision(AbstractMetricCalculator): | ||
def __init__(self, params): | ||
super().__init__(params) | ||
|
||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args): | ||
metric_results = {} | ||
|
||
for metric_name, metric_calculator in self.metrics_calculators.items(): | ||
metric_results[metric_name] = 0.0 | ||
for i, _ in enumerate(prediction): | ||
metric_value = ( | ||
metric_calculator(prediction[i], target[i], self.params) | ||
.detach() | ||
.cpu() | ||
) | ||
metric_results[metric_name] += self._process_metric_value(metric_value) | ||
return metric_results | ||
|
||
|
||
class MetricCalculatorSimple(AbstractMetricCalculator): | ||
def __init__(self, params): | ||
super().__init__(params) | ||
|
||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor, *args): | ||
metric_results = {} | ||
|
||
for metric_name, metric_calculator in self.metrics_calculators.items(): | ||
metric_value = ( | ||
metric_calculator(prediction, target, self.params).detach().cpu() | ||
) | ||
metric_results[metric_name] = self._process_metric_value(metric_value) | ||
return metric_results | ||
|
||
|
||
class MetricCalculatorFactory: | ||
def __init__(self, params: dict): | ||
self.params = params | ||
|
||
def get_metric_calculator(self) -> AbstractMetricCalculator: | ||
if self.params["model"]["architecture"] == "sdnet": | ||
return MetricCalculatorSDNet(self.params) | ||
elif "deep" in self.params["model"]["architecture"].lower(): | ||
return MetricCalculatorDeepSupervision(self.params) | ||
else: | ||
return MetricCalculatorSimple(self.params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import lightning.pytorch as pl | ||
from GANDLF.models import get_model | ||
from GANDLF.optimizers import get_optimizer | ||
from GANDLF.schedulers import get_scheduler | ||
from GANDLF.losses.loss_calculators import LossCalculatorFactory | ||
from GANDLF.metrics.metric_calculators import MetricCalculatorFactory | ||
from GANDLF.utils.pred_target_processors import PredictionTargetProcessorFactory | ||
|
||
from copy import deepcopy | ||
|
||
|
||
class GandlfLightningModule(pl.LightningModule): | ||
def __init__(self, params: dict): | ||
super().__init__() | ||
self.params = deepcopy(params) | ||
self._initialize_model() | ||
self._initialize_loss() | ||
self._initialize_metric_calculators() | ||
self._initialize_preds_target_processor() | ||
|
||
def _initialize_model(self): | ||
self.model = get_model(self.params) | ||
|
||
def _initialize_loss(self): | ||
self.loss = LossCalculatorFactory(self.params).get_loss_calculator() | ||
|
||
def _initialize_metric_calculators(self): | ||
self.metric_calculators = MetricCalculatorFactory( | ||
self.params | ||
).get_metric_calculator() | ||
|
||
def _initialize_preds_target_processor(self): | ||
self.pred_target_processor = PredictionTargetProcessorFactory( | ||
self.params | ||
).get_prediction_target_processor() | ||
|
||
def configure_optimizers(self): | ||
params = deepcopy(self.params) | ||
params["model_parameters"] = self.model.parameters() | ||
optimizer = get_optimizer(params) | ||
if "scheduler" in self.params: | ||
params["optimizer_object"] = optimizer | ||
scheduler = get_scheduler(params) | ||
return [optimizer], [scheduler] | ||
return optimizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.