Skip to content

Commit

Permalink
Merge pull request #24 from szmazurek/feature/improving_metrics_synth…
Browse files Browse the repository at this point in the history
…esis

Feature/improving metrics synthesis
  • Loading branch information
szmazurek authored Feb 23, 2024
2 parents 56d577c + 804a3fe commit be16be5
Show file tree
Hide file tree
Showing 15 changed files with 733 additions and 452 deletions.
92 changes: 63 additions & 29 deletions GANDLF/GAN/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from GANDLF.metrics import overall_stats
from tqdm import tqdm
from typing import Union, Tuple
from typing import Union, Tuple, Optional
from GANDLF.models.modelBase import ModelBase
from .generic import get_fixed_latent_vector

Expand Down Expand Up @@ -95,7 +95,6 @@ def validate_network_gan(
# Set the model to valid
if params["model"]["type"] == "torch":
model.eval()
FIXED_LATENT_VECTOR = get_fixed_latent_vector(params, mode)

for batch_idx, (subject) in enumerate(
tqdm(dataloader, desc="Looping over " + mode + " data")
Expand Down Expand Up @@ -183,21 +182,32 @@ def validate_network_gan(
if (
batch_idx == 0
): # generate the fake images only ONCE, as they are fixed
fake_images = model.generator(FIXED_LATENT_VECTOR)
original_batch_size = params["batch_size"]
params[
"batch_size"
] = 1 # set this for patch-wise inference
fake_images = model.generator(
get_fixed_latent_vector(params, mode)
)
params[
"batch_size"
] = original_batch_size # restore the param
loss_fake, _, output_disc_fake, _ = step_gan(
model,
fake_images,
label_fake,
params,
secondary_images=None,
)

loss_disc_real, metrics_output, output_disc_real, _ = step_gan(
model,
image,
label_real,
params,
secondary_images=fake_images[: current_batch_size - 1],
secondary_images=fake_images,
)

for metric in params["metrics"]:
# average over all patches for the current subject
total_epoch_metrics[metric] += metrics_output[metric] / len(
Expand Down Expand Up @@ -233,32 +243,54 @@ def validate_network_gan(
tensor=subject["1"]["data"].squeeze(0),
affine=subject["1"]["affine"].squeeze(0),
).as_sitk()
fake_images_batch = fake_images.cpu().numpy()
# generate ENTIRE batch of fake

# perform postprocessing before reverse one-hot encoding here

# if jpg detected, convert to 8-bit arrays
ext = get_filename_extension_sanitized(subject["1"]["path"][0])
if ext in [
".jpg",
".jpeg",
".png",
]:
fake_images_batch = fake_images_batch.astype(np.uint8)
# ext = get_filename_extension_sanitized(subject["1"]["path"][0])
### TODO do not use this, temporary only for debugging
ext = ".png"
# if ext in [
# ".jpg",
# ".jpeg",
# ".png",
# ]:
# fake_images_batch = fake_images_batch.astype(np.uint8)
with torch.no_grad():
fake_images_to_save = (
model.generator(get_fixed_latent_vector(params, mode)).cpu()
# .numpy()
) # generate fake batch for saving

## special case for 2D
if image.shape[-1] > 1:
result_image = sitk.GetImageFromArray(fake_images_batch)
else:
result_image = sitk.GetImageFromArray(fake_images_batch.squeeze(0))
# result_image.CopyInformation(img_for_metadata)
# for i, fake_image_to_save in enumerate(fake_images_to_save[:16]):
# fake_image_to_save = ((fake_image_to_save + 1) * (255 / 2)).astype(
# np.uint8
# )
# # if ext in [
# # ".jpg",
# # ".jpeg",
# # ".png",
# # ]:
# # fake_image_to_save = fake_image_to_save.astype(np.uint8)
# print(fake_image_to_save.shape)
# print(np.min(fake_image_to_save), np.max(fake_image_to_save))
# if image.shape[-1] > 1:
# result_image = sitk.GetImageFromArray(fake_image_to_save)
# else:
# result_image = sitk.GetImageFromArray(
# fake_image_to_save.squeeze()
# )
# # result_image.CopyInformation(img_for_metadata)

# this handles cases that need resampling/resizing
if "resample" in params["data_preprocessing"]:
result_image = resample_image(
result_image,
img_for_metadata.GetSpacing(),
interpolator=sitk.sitkNearestNeighbor,
)
# # this handles cases that need resampling/resizing
# if "resample" in params["data_preprocessing"]:
# result_image = resample_image(
# result_image,
# img_for_metadata.GetSpacing(),
# interpolator=sitk.sitkNearestNeighbor,
# )
# Create the subject directory if it doesn't exist in the
# current_output_dir directory
os.makedirs(
Expand All @@ -278,13 +310,15 @@ def validate_network_gan(
current_output_dir,
"testing",
subject["subject_id"][0],
subject["subject_id"][0] + "_gen" + ext,
)
sitk.WriteImage(
result_image,
path_to_save,
subject["subject_id"][0] + f"_gen" + ext,
)
# sitk.WriteImage(
# result_image,
# path_to_save,
# )
import torchvision.utils as vutils

vutils.save_image(fake_images_to_save, path_to_save, normalize=True)
if scheduler_d is not None:
if params["scheduler_d"]["type"] in [
"reduce_on_plateau",
Expand Down
1 change: 0 additions & 1 deletion GANDLF/GAN/compute/inference_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def inference_loop_gans(
generated_images = model(latent_vector)
generated_images = generated_images.cpu().to(torch.uint8)
for i in range(generated_images.shape[0]):

if parameters["model"]["dimension"] == 2:
image_to_save = (
generated_images[i].permute(1, 2, 0).numpy()
Expand Down
1 change: 0 additions & 1 deletion GANDLF/GAN/compute/loss_and_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def get_metric_output(metric_function, predicted, ground_truth, params):
metric_output = (
metric_function(predicted, ground_truth, params).detach().cpu()
)

if metric_output.dim() == 0:
return metric_output.item()
else:
Expand Down
25 changes: 17 additions & 8 deletions GANDLF/GAN/compute/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
version_check,
write_training_patches,
print_model_summary,
get_ground_truths_and_predictions_tensor,
get_model_dict,
print_and_format_metrics,
)
from GANDLF.metrics import overall_stats
from GANDLF.logger import LoggerGAN
from .step import step_gan
from .forward_pass import validate_network_gan
Expand Down Expand Up @@ -93,8 +91,6 @@ def train_network_gan(
):
#### DISCRIMINATOR STEP WITH ALL REAL LABELS ####
optimizer_d.zero_grad()
optimizer_g.zero_grad()

image_real = (
torch.cat(
[subject[key][torchio.DATA] for key in params["channel_keys"]],
Expand Down Expand Up @@ -166,7 +162,7 @@ def train_network_gan(
fake_images = model.generator(latent_vector)
loss_disc_fake, _, output_disc_fake, _ = step_gan(
model,
fake_images,
fake_images.detach(),
label_fake,
params,
secondary_images=None,
Expand Down Expand Up @@ -203,18 +199,17 @@ def train_network_gan(
if not nan_loss:
total_epoch_train_loss_disc += loss_disc.detach().cpu().item()
optimizer_d.step()
optimizer_d.zero_grad()
### GENERATOR STEP ###
optimizer_g.zero_grad()
label_fake = label_real.fill_(1)
# TODO should we really use THE SAME fake images?
loss_gen, calculated_metrics, output_gen_step, _ = step_gan(
model,
fake_images.detach(),
fake_images,
label_fake,
params,
secondary_images=image_real,
)

nan_loss = torch.isnan(loss_gen)
second_order = (
hasattr(optimizer_g, "is_second_order")
Expand Down Expand Up @@ -249,6 +244,19 @@ def train_network_gan(
optimizer_g.zero_grad()
if not nan_loss:
total_epoch_train_loss_gen += loss_gen.detach().cpu().item()
for metric in calculated_metrics.keys():
if isinstance(total_epoch_train_metric[metric], list):
if len(total_epoch_train_metric[metric]) == 0:
total_epoch_train_metric[metric] = np.array(
calculated_metrics[metric]
)
else:
total_epoch_train_metric[metric] += np.array(
calculated_metrics[metric]
)
else:
total_epoch_train_metric[metric] += calculated_metrics[metric]

average_epoch_train_loss_gen = total_epoch_train_loss_gen / len(
train_dataloader
)
Expand Down Expand Up @@ -383,6 +391,7 @@ def training_loop_gans(
) = create_pytorch_objects_gan(
params, training_data, validation_data, device
)
print(f"Train dataloader length: {len(train_dataloader)}")
# save the initial model
if not os.path.exists(model_paths["initial"]):
# TODO check if the saving is indeed correct
Expand Down
1 change: 1 addition & 0 deletions GANDLF/GAN/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .generation import SSIM, FID, LPIPS
67 changes: 33 additions & 34 deletions GANDLF/GAN/metrics/gan_utils/functional/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def _upsample(in_tens: Tensor, out_hw: Tuple[int, ...] = (64, 64)) -> Tensor:

def _normalize_tensor(in_feat: Tensor, eps: float = 1e-8) -> Tensor:
"""Normalize input tensor."""
norm_factor = torch.sqrt(
eps + torch.sum(in_feat**2, dim=1, keepdim=True)
)
norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / norm_factor


Expand Down Expand Up @@ -117,7 +115,6 @@ def __init__(

self.n_dim = n_dim

@torch.no_grad
def forward(
self,
in0: Tensor,
Expand All @@ -138,32 +135,35 @@ def forward(
if self.resize is not None:
in0_input = _resize_tensor(in0_input, size=self.resize)
in1_input = _resize_tensor(in1_input, size=self.resize)

outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}

for kk in range(self.L):
feats0[kk], feats1[kk] = _normalize_tensor(
outs0[kk]
), _normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

res = []
for kk in range(self.L):
if self.spatial:
res.append(
_upsample(
self.lins[kk](diffs[kk]), out_hw=tuple(in0.shape[2:])
with torch.no_grad():
outs0, outs1 = self.net.forward(in0_input), self.net.forward(
in1_input
)
feats0, feats1, diffs = {}, {}, {}

for kk in range(self.L):
feats0[kk], feats1[kk] = _normalize_tensor(
outs0[kk]
), _normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

res = []
for kk in range(self.L):
if self.spatial:
res.append(
_upsample(
self.lins[kk](diffs[kk]),
out_hw=tuple(in0.shape[2:]),
)
)
)
else:
res.append(
_spatial_average(
self.lins[kk](diffs[kk]),
n_dims=self.n_dim,
keep_dim=True,
else:
res.append(
_spatial_average(
self.lins[kk](diffs[kk]),
n_dims=self.n_dim,
keep_dim=True,
)
)
)
val: Tensor = sum(res) # type: ignore[assignment]
if retperlayer:
return (val, res)
Expand Down Expand Up @@ -237,8 +237,8 @@ def lpips_compute(


def determine_converter(
converter_type: Union[str, None]
) -> Union[None, SoftACSConverter, ACSConverter, Conv3dConverter]:
converter_type: Literal["soft", "acs", "conv3d"] = "soft",
) -> Union[SoftACSConverter, ACSConverter, Conv3dConverter]:
"""Determine the converter type to use for 2D to 3D conversion.
Args:
converter_type: str indicating the type of converter to use for
Expand All @@ -256,7 +256,7 @@ def determine_converter(
elif converter_type == "conv3d":
converter = Conv3dConverter
else:
raise ValueError(f"Unknown converter type {converter}")
raise ValueError(f"Unknown converter type {converter_type}")
return converter


Expand Down Expand Up @@ -315,7 +315,7 @@ def learned_perceptual_image_patch_similarity(
normalize: bool = False,
n_dim: int = 2,
n_channels: int = 1,
converter_type: Union[str, None] = None,
converter_type: Literal["soft", "acs", "conv3d"] = "soft",
) -> Tensor:
"""Functional Interface for The Learned Perceptual Image Patch Similarity
(`LPIPS_`) calculates perceptual similarity between two images.
Expand All @@ -341,8 +341,7 @@ def learned_perceptual_image_patch_similarity(
expect input to be in the ``[0,1]`` range.
converter_type: str indicating the type of converter to use for
converting the net into a 3D network if the input is 5D. Choose
between `'soft'`, `'acs'`, `'conv3d'`. If ``None`` will use
`'soft'` by default.
between `'soft'`, `'acs'`, `'conv3d'`. Will use `'soft'` by default.
Example:
>>> import torch
Expand Down
10 changes: 7 additions & 3 deletions GANDLF/GAN/metrics/gan_utils/lpip.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
normalize: bool = False,
n_dim: int = 2,
n_channels: int = 1,
converter_type: Union[str, None] = None,
converter_type: Literal["soft", "acs", "conv3d"] = "soft",
**kwargs: Any,
):
"""Initialize the LPIPS metric for GanDLF. This metric is based on the
Expand All @@ -53,8 +53,8 @@ def __init__(
normalize (bool): Whether to normalize the input images.
n_dim (int): The number of dimensions of the input images.
n_channels (int): The number of channels of the input images.
converter_type (Union[str, None]): The converter type from ACS, one of
'soft','asc' or 'conv3d'. If None, defaults to 'soft'.
converter_type (Literal["soft", "acs", "conv3d"]): The converter type
from ACS, one of 'soft', 'acs' or 'conv3d'. Defaults to 'soft'.
**kwargs: Additional arguments for the metric.
"""

Expand Down Expand Up @@ -145,3 +145,7 @@ def plot(
"""
return self._plot(val, ax)


if __name__ == "__main__":
calc = LPIPSGandlf()
Loading

0 comments on commit be16be5

Please sign in to comment.