diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 54c3ac7b..43605e71 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -17,6 +17,7 @@ requirements: build: - python >=3.9 - pip + - setuptools run: - python >=3.9 diff --git a/pytorch3dunet/augment/transforms.py b/pytorch3dunet/augment/transforms.py index 527d596b..ec1abb24 100644 --- a/pytorch3dunet/augment/transforms.py +++ b/pytorch3dunet/augment/transforms.py @@ -4,7 +4,7 @@ import numpy as np import torch from scipy.ndimage import rotate, map_coordinates, gaussian_filter, convolve -from skimage import measure +from skimage import measure, exposure from skimage.filters import gaussian from skimage.segmentation import find_boundaries @@ -133,6 +133,27 @@ def __call__(self, m): return m +class RandomGammaCorrection: + """ + Adjust contrast by scaling each voxel to `v ** gamma`. + """ + + def __init__(self, random_state, gamma=(0.5, 1.5), execution_probability=0.1, **kwargs): + self.random_state = random_state + assert len(gamma) == 2 + self.gamma = gamma + self.execution_probability = execution_probability + + def __call__(self, m): + if self.random_state.uniform() < self.execution_probability: + # rescale intensity values to [0, 1] + m = exposure.rescale_intensity(m, out_range=(0, 1)) + gamma = self.random_state.uniform(self.gamma[0], self.gamma[1]) + return exposure.adjust_gamma(m, gamma) + + return m + + # it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader # remember to use spline_order=0 when transforming the labels class ElasticDeformation: @@ -576,12 +597,12 @@ def __call__(self, m): # check if non None in self.min_value/self.max_value # if present and if so copy value to min_value if self.min_value is not None: - for i,v in enumerate(self.min_value): + for i, v in enumerate(self.min_value): if v != 'None': min_value[i] = v if self.max_value is not None: - for i,v in enumerate(self.max_value): + for i, v in enumerate(self.max_value): if v != 'None': max_value[i] = v else: @@ -600,9 +621,9 @@ def __call__(self, m): norm_0_1 = (m - min_value) / (max_value - min_value + self.eps) if self.norm01 is True: - return np.clip(norm_0_1, 0, 1) + return np.clip(norm_0_1, 0, 1) else: - return np.clip(2 * norm_0_1 - 1, -1, 1) + return np.clip(2 * norm_0_1 - 1, -1, 1) class AdditiveGaussianNoise: @@ -640,11 +661,13 @@ class ToTensor: Args: expand_dims (bool): if True, adds a channel dimension to the input data dtype (np.dtype): the desired output data type + normalize (bool): zero-one normalization of the input data """ - def __init__(self, expand_dims, dtype=np.float32, **kwargs): + def __init__(self, expand_dims, dtype=np.float32, normalize=False, **kwargs): self.expand_dims = expand_dims self.dtype = dtype + self.normalize = normalize def __call__(self, m): assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images' @@ -652,6 +675,10 @@ def __call__(self, m): if self.expand_dims and m.ndim == 3: m = np.expand_dims(m, axis=0) + if self.normalize: + # avoid division by zero + m = (m - np.min(m)) / (np.max(m) - np.min(m) + 1e-10) + return torch.from_numpy(m.astype(dtype=self.dtype)) @@ -706,7 +733,7 @@ def __call__(self, m): class GaussianBlur3D: - def __init__(self, sigma=[.1, 2.], execution_probability=0.5, **kwargs): + def __init__(self, sigma=(.1, 2.), execution_probability=0.5, **kwargs): self.sigma = sigma self.execution_probability = execution_probability diff --git a/pytorch3dunet/datasets/hdf5.py b/pytorch3dunet/datasets/hdf5.py index 040adb85..e5ec1176 100644 --- a/pytorch3dunet/datasets/hdf5.py +++ b/pytorch3dunet/datasets/hdf5.py @@ -1,12 +1,13 @@ import glob import os from abc import abstractmethod +from concurrent.futures.process import ProcessPoolExecutor from itertools import chain import h5py import pytorch3dunet.augment.transforms as transforms -from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad +from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad, RandomScaler from pytorch3dunet.unet3d.utils import get_logger logger = get_logger('HDF5Dataset') @@ -44,10 +45,14 @@ class AbstractHDF5Dataset(ConfigDataset): label_internal_path (str or list): H5 internal path to the label dataset weight_internal_path (str or list): H5 internal path to the per pixel weights (optional) global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset + random_scale (int): if not None, the raw data will be randomly shifted by a value in the range + [-random_scale, random_scale] in each dimension and then scaled to the original patch shape + random_scale_probability (float): probability of executing the random scale on a patch """ def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw', - label_internal_path='label', weight_internal_path=None, global_normalization=True): + label_internal_path='label', weight_internal_path=None, global_normalization=True, + random_scale=None, random_scale_probability=0.5): assert phase in ['train', 'val', 'test'] self.phase = phase @@ -94,6 +99,10 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r with h5py.File(file_path, 'r') as f: raw = f[raw_internal_path] + if raw.ndim == 3: + self.volume_shape = raw.shape + else: + self.volume_shape = raw.shape[1:] label = f[label_internal_path] if phase != 'test' else None weight_map = f[weight_internal_path] if weight_internal_path is not None else None # build slice indices for raw and label data sets @@ -102,8 +111,18 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r self.label_slices = slice_builder.label_slices self.weight_slices = slice_builder.weight_slices + if random_scale is not None: + assert isinstance(random_scale, int), 'random_scale must be an integer' + stride_shape = slice_builder_config.get('stride_shape') + assert all(random_scale < stride for stride in stride_shape), \ + f"random_scale {random_scale} must be smaller than each of the strides {stride_shape}" + patch_shape = slice_builder_config.get('patch_shape') + self.random_scaler = RandomScaler(random_scale, patch_shape, self.volume_shape, random_scale_probability) + logger.info(f"Using RandomScaler with offset range {random_scale}") + else: + self.random_scaler = None + self.patch_count = len(self.raw_slices) - logger.info(f'Number of patches: {self.patch_count}') @abstractmethod def get_raw_patch(self, idx): @@ -121,14 +140,6 @@ def get_weight_patch(self, idx): def get_raw_padded_patch(self, idx): raise NotImplementedError - def volume_shape(self): - with h5py.File(self.file_path, 'r') as f: - raw = f[self.raw_internal_path] - if raw.ndim == 3: - return raw.shape - else: - return raw.shape[1:] - def __getitem__(self, idx): if idx >= len(self): raise StopIteration @@ -146,15 +157,24 @@ def __getitem__(self, idx): raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded)) return raw_patch_transformed, raw_idx else: - raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx)) - - # get the slice for a given index 'idx' label_idx = self.label_slices[idx] + + if self.random_scaler is not None: + # randomize the indices + raw_idx, label_idx = self.random_scaler.randomize_indices(raw_idx, label_idx) + + raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx)) label_patch_transformed = self.label_transform(self.get_label_patch(label_idx)) if self.weight_internal_path is not None: weight_idx = self.weight_slices[idx] weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx)) return raw_patch_transformed, label_patch_transformed, weight_patch_transformed + + if self.random_scaler is not None: + # scale patches back to the original patch size + raw_patch_transformed, label_patch_transformed = self.random_scaler.rescale_patches( + raw_patch_transformed, label_patch_transformed + ) # return the transformed raw and label patches return raw_patch_transformed, label_patch_transformed @@ -192,22 +212,31 @@ def create_datasets(cls, dataset_config, phase): # are going to be included in the final file_paths file_paths = traverse_h5_paths(file_paths) - datasets = [] - for file_path in file_paths: - try: + # create datasets concurrently + with ProcessPoolExecutor() as executor: + futures = [] + for file_path in file_paths: logger.info(f'Loading {phase} set from: {file_path}...') - dataset = cls(file_path=file_path, - phase=phase, - slice_builder_config=slice_builder_config, - transformer_config=transformer_config, - raw_internal_path=dataset_config.get('raw_internal_path', 'raw'), - label_internal_path=dataset_config.get('label_internal_path', 'label'), - weight_internal_path=dataset_config.get('weight_internal_path', None), - global_normalization=dataset_config.get('global_normalization', None)) - datasets.append(dataset) - except Exception: - logger.error(f'Skipping {phase} set: {file_path}', exc_info=True) - return datasets + future = executor.submit(cls, file_path=file_path, + phase=phase, + slice_builder_config=slice_builder_config, + transformer_config=transformer_config, + raw_internal_path=dataset_config.get('raw_internal_path', 'raw'), + label_internal_path=dataset_config.get('label_internal_path', 'label'), + weight_internal_path=dataset_config.get('weight_internal_path', None), + global_normalization=dataset_config.get('global_normalization', None), + random_scale=dataset_config.get('random_scale', None), + random_scale_probability=dataset_config.get('random_scale_probability', 0.5)) + futures.append(future) + + datasets = [] + for future in futures: + try: + dataset = future.result() + datasets.append(dataset) + except Exception as e: + logger.error(f'Failed to load dataset: {e}') + return datasets class StandardHDF5Dataset(AbstractHDF5Dataset): @@ -218,11 +247,12 @@ class StandardHDF5Dataset(AbstractHDF5Dataset): def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw', label_internal_path='label', weight_internal_path=None, - global_normalization=True): + global_normalization=True, random_scale=None, random_scale_probability=0.5): super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config, transformer_config=transformer_config, raw_internal_path=raw_internal_path, label_internal_path=label_internal_path, weight_internal_path=weight_internal_path, - global_normalization=global_normalization) + global_normalization=global_normalization, random_scale=random_scale, + random_scale_probability=random_scale_probability) self._raw = None self._raw_padded = None self._label = None @@ -262,11 +292,12 @@ class LazyHDF5Dataset(AbstractHDF5Dataset): def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw', label_internal_path='label', weight_internal_path=None, - global_normalization=False): + global_normalization=False, random_scale=None, random_scale_probability=0.5): super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config, transformer_config=transformer_config, raw_internal_path=raw_internal_path, label_internal_path=label_internal_path, weight_internal_path=weight_internal_path, - global_normalization=global_normalization) + global_normalization=global_normalization, random_scale=random_scale, + random_scale_probability=random_scale_probability) logger.info("Using LazyHDF5Dataset") diff --git a/pytorch3dunet/datasets/utils.py b/pytorch3dunet/datasets/utils.py index 1ffeefe4..263c00b8 100644 --- a/pytorch3dunet/datasets/utils.py +++ b/pytorch3dunet/datasets/utils.py @@ -3,6 +3,7 @@ import numpy as np import torch +from torch.nn.functional import interpolate from torch.utils.data import DataLoader, ConcatDataset, Dataset from pytorch3dunet.unet3d.utils import get_logger, get_class @@ -10,7 +11,133 @@ logger = get_logger('Dataset') +class RandomScaler: + """ + Randomly scales the raw and label patches. + """ + + def __init__(self, scale_range: int, patch_shape: tuple, volume_shape: tuple, execution_probability: bool = 0.5, + seed: int = 47): + self.scale_range = scale_range + self.patch_shape = patch_shape + self.volume_shape = volume_shape + self.execution_probability = execution_probability + self.rs = np.random.RandomState(seed) + + def randomize_indices(self, raw_idx: tuple, label_idx: tuple) -> tuple[tuple, tuple]: + # execute scaling with a given probability + if self.rs.uniform() < self.execution_probability: + return raw_idx, label_idx + + # select random offsets for scaling + offsets = [self.rs.randint(self.scale_range) for _ in range(3)] + # change offset sign at random + if self.rs.rand() > 0.5: + offsets = [-o for o in offsets] + # apply offsets to the start or end of the slice at random + is_start = self.rs.rand() > 0.5 + raw_idx = self._apply_offsets(raw_idx, offsets, is_start) + label_idx = self._apply_offsets(label_idx, offsets, is_start) + + # assert spatial dimensions are the same + if len(raw_idx) == 4: + raw_idx_spacial = raw_idx[1:] + else: + raw_idx_spacial = raw_idx + if len(label_idx) == 4: + label_idx_spacial = label_idx[1:] + else: + label_idx_spacial = label_idx + assert raw_idx_spacial == label_idx_spacial, f"Raw and label indices are different: {raw_idx_spacial} != {label_idx_spacial}" + + return raw_idx, label_idx + + def rescale_patches(self, raw_patch: torch.Tensor, label_patch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # compute zoom factors + if raw_patch.ndim == 4: + raw_shape = raw_patch.shape[1:] + else: + raw_shape = raw_patch.shape + + # if raw_shape equal to self.patch_shape just return the patches + if raw_shape == self.patch_shape: + return raw_patch, label_patch + + # rescale patches back to the original shape + if raw_patch.ndim == 4: + # add batch dimension + raw_patch = raw_patch.unsqueeze(0) + remove_dims = 1 + else: + # add batch and channels dimensions + raw_patch = raw_patch.unsqueeze(0).unsqueeze(0) + remove_dims = 2 + + # interpolate raw patch + raw_patch = interpolate(raw_patch, self.patch_shape, mode='trilinear') + # remove additional dimensions + for _ in range(remove_dims): + raw_patch = raw_patch.squeeze(0) + + if label_patch.ndim == 4: + label_patch = label_patch.unsqueeze(0) + remove_dims = 1 + else: + label_patch = label_patch.unsqueeze(0).unsqueeze(0) + remove_dims = 2 + + label_dtype = label_patch.dtype + # check if label patch is of torch int type + if label_dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]: + # convert to float for interpolation + label_patch = label_patch.float() + + # interpolate label patch + label_patch = interpolate(label_patch, self.patch_shape, mode='nearest') + + # remove additional dimensions + for _ in range(remove_dims): + label_patch = label_patch.squeeze(0) + + # convert back to int if necessary + if label_dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]: + if label_dtype == torch.int64: + label_patch = label_patch.long() + else: + label_patch = label_patch.int() + + return raw_patch, label_patch + + def _apply_offsets(self, idx: tuple, offsets: list, is_start: bool) -> tuple: + if len(idx) == 4: + spatial_idx = idx[1:] + else: + spatial_idx = idx + + new_idx = [] + for i, o, s in zip(spatial_idx, offsets, self.volume_shape): + if is_start: + # prevent negative start + start = max(0, i.start + o) + stop = i.stop + else: + start = i.start + # prevent stop exceeding the volume shape + stop = min(s, i.stop + o) + + new_idx.append(slice(start, stop)) + + if len(idx) == 4: + return (idx[0],) + tuple(new_idx) + + return tuple(new_idx) + + class ConfigDataset(Dataset): + """ + Abstract class for datasets that are configured via a dictionary. + """ + def __getitem__(self, index): raise NotImplementedError @@ -155,8 +282,12 @@ def ignore_predicate(raw_label_idx): zipped_slices = zip(self.raw_slices, self.label_slices) # ignore slices containing too much ignore_index - logger.info(f'Filtering slices...') filtered_slices = list(filter(ignore_predicate, zipped_slices)) + # log number of filtered patches + logger.info( + f"Loading {len(filtered_slices)} out of {len(self.raw_slices)} patches: " + f"{int(100 * len(filtered_slices) / len(self.raw_slices))}%" + ) # unzip and save slices raw_slices, label_slices = zip(*filtered_slices) self._raw_slices = list(raw_slices) @@ -220,10 +351,10 @@ def get_train_loaders(config): # when training with volumetric data use batch_size of 1 due to GPU memory constraints return { 'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True, pin_memory=True, - num_workers=num_workers), + num_workers=num_workers, drop_last=True), # don't shuffle during validation: useful when showing how predictions for a given batch get better over time 'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=False, pin_memory=True, - num_workers=num_workers) + num_workers=num_workers, drop_last=True) } diff --git a/pytorch3dunet/predict.py b/pytorch3dunet/predict.py index cc54fcf7..2340f049 100755 --- a/pytorch3dunet/predict.py +++ b/pytorch3dunet/predict.py @@ -14,8 +14,6 @@ def get_predictor(model, config): output_dir = config['loaders'].get('output_dir', None) - # override output_dir if provided in the 'predictor' section of the config - output_dir = config.get('predictor', {}).get('output_dir', output_dir) if output_dir is not None: os.makedirs(output_dir, exist_ok=True) diff --git a/pytorch3dunet/unet3d/losses.py b/pytorch3dunet/unet3d/losses.py index 6a53966f..a76c733c 100644 --- a/pytorch3dunet/unet3d/losses.py +++ b/pytorch3dunet/unet3d/losses.py @@ -3,6 +3,9 @@ from torch import nn as nn from torch.nn import MSELoss, SmoothL1Loss, L1Loss +from pytorch3dunet.unet3d.utils import get_logger + +logger = get_logger('Loss') def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None): """ @@ -167,15 +170,14 @@ def dice(self, input, target, weight): class BCEDiceLoss(nn.Module): """Linear combination of BCE and Dice losses""" - def __init__(self, alpha, beta): + def __init__(self, alpha=1.0): super(BCEDiceLoss, self).__init__() self.alpha = alpha self.bce = nn.BCEWithLogitsLoss() - self.beta = beta self.dice = DiceLoss() def forward(self, input, target): - return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target) + return self.bce(input, target) + self.alpha * self.dice(input, target) class WeightedCrossEntropyLoss(nn.Module): @@ -279,13 +281,15 @@ def get_loss_criterion(config): assert 'loss' in config, 'Could not find loss function configuration' loss_config = config['loss'] name = loss_config.pop('name') + logger.info(f"Creating loss function: {name}") ignore_index = loss_config.pop('ignore_index', None) skip_last_target = loss_config.pop('skip_last_target', False) weight = loss_config.pop('weight', None) if weight is not None: - weight = torch.tensor(weight) + weight = torch.tensor(weight).float() + logger.info(f"Using class weights: {weight}") pos_weight = loss_config.pop('pos_weight', None) if pos_weight is not None: @@ -313,8 +317,7 @@ def _create_loss(name, loss_config, weight, ignore_index, pos_weight): return nn.BCEWithLogitsLoss(pos_weight=pos_weight) elif name == 'BCEDiceLoss': alpha = loss_config.get('alpha', 1.) - beta = loss_config.get('beta', 1.) - return BCEDiceLoss(alpha, beta) + return BCEDiceLoss(alpha) elif name == 'CrossEntropyLoss': if ignore_index is None: ignore_index = -100 # use the default 'ignore_index' as defined in the CrossEntropyLoss diff --git a/pytorch3dunet/unet3d/metrics.py b/pytorch3dunet/unet3d/metrics.py index 6764eec9..07472241 100644 --- a/pytorch3dunet/unet3d/metrics.py +++ b/pytorch3dunet/unet3d/metrics.py @@ -7,7 +7,7 @@ from pytorch3dunet.unet3d.losses import compute_per_channel_dice from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy -from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, convert_to_numpy +from pytorch3dunet.unet3d.utils import get_logger, convert_to_numpy logger = get_logger('EvalMetric') @@ -32,15 +32,13 @@ def __call__(self, input, target): class MeanIoU: """ Computes IoU for each class separately and then averages over all classes. + + Args: + skip_background (bool): if True, background class (i.e. 0-label) will be skipped when computing IoU """ - def __init__(self, skip_channels=(), ignore_index=None, **kwargs): - """ - :param skip_channels: list/tuple of channels to be ignored from the IoU computation - :param ignore_index: id of the label to be ignored from IoU computation - """ - self.ignore_index = ignore_index - self.skip_channels = skip_channels + def __init__(self, skip_background=True, **kwargs): + self.skip_background = skip_background def __call__(self, input, target): """ @@ -50,58 +48,46 @@ def __call__(self, input, target): """ assert input.dim() == 5 - n_classes = input.size()[1] + n_classes = input.size(1) if target.dim() == 4: - target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index) + # convert input to segmentation + input = input.argmax(dim=1) assert input.size() == target.size() per_batch_iou = [] for _input, _target in zip(input, target): - binary_prediction = self._binarize_predictions(_input, n_classes) - - if self.ignore_index is not None: - # zero out ignore_index - mask = _target == self.ignore_index - binary_prediction[mask] = 0 - _target[mask] = 0 - - # convert to uint8 just in case - binary_prediction = binary_prediction.byte() + # convert target to byte _target = _target.byte() - per_channel_iou = [] - for c in range(n_classes): - if c in self.skip_channels: - continue - - per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c])) + start_idx = 0 + # skip background only if target is 4D; for channel-wise computation (i.e. if target is 5D) we need to include it + if self.skip_background and target.dim() == 4: + start_idx = 1 + + for c in range(start_idx, n_classes): + if target.dim() == 5: + iou = self._jaccard_index(_input[c] > 0.5, _target[c]) + per_channel_iou.append(iou) + else: + iou = self._jaccard_index(_input == c, _target == c) + per_channel_iou.append(iou) assert per_channel_iou, "All channels were ignored from the computation" - mean_iou = torch.mean(torch.tensor(per_channel_iou)) + mean_iou = torch.tensor(per_channel_iou).mean() per_batch_iou.append(mean_iou) - return torch.mean(torch.tensor(per_batch_iou)) - - def _binarize_predictions(self, input, n_classes): - """ - Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the - same size as the input tensor. - """ - if n_classes == 1: - # for single channel input just threshold the probability map - result = input > 0.5 - return result.long() - - _, max_index = torch.max(input, dim=0, keepdim=True) - return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1) + return torch.tensor(per_batch_iou).mean() def _jaccard_index(self, prediction, target): """ Computes IoU for a given target and prediction tensors """ - return torch.sum(prediction & target).float() / torch.clamp(torch.sum(prediction | target).float(), min=1e-8) + epsilon = 1e-8 + intersection = torch.logical_and(target, prediction).sum() + union = torch.logical_or(target, prediction).sum() + return (intersection + epsilon) / (union + epsilon) class AdaptedRandError: diff --git a/pytorch3dunet/unet3d/model.py b/pytorch3dunet/unet3d/model.py index e4de49a7..4cf3ddd7 100644 --- a/pytorch3dunet/unet3d/model.py +++ b/pytorch3dunet/unet3d/model.py @@ -1,4 +1,4 @@ -import torch.nn as nn +from torch import nn from pytorch3dunet.unet3d.buildingblocks import DoubleConv, ResNetBlock, ResNetBlockSE, \ create_decoders, create_encoders @@ -81,7 +81,27 @@ def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_map # regression problem self.final_activation = None - def forward(self, x): + def forward(self, x, return_logits=False): + """ + Forward pass through the network. + + Args: + x (torch.Tensor): Input tensor of shape (N, C, D, H, W) for 3D or (N, C, H, W) for 2D, + where N is the batch size, C is the number of channels, + D is the depth, H is the height, and W is the width. + return_logits (bool): If True, returns both the output and the logits. + If False, returns only the output. Default is False. + + Returns: + torch.Tensor: The output tensor after passing through the network. + If return_logits is True, returns a tuple of (output, logits). + """ + output, logits = self._forward_logits(x) + if return_logits: + return output, logits + return output + + def _forward_logits(self, x): # encoder part encoders_features = [] for encoder in self.encoders: @@ -101,12 +121,13 @@ def forward(self, x): x = self.final_conv(x) - # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. - # During training the network outputs logits - if not self.training and self.final_activation is not None: - x = self.final_activation(x) + if self.final_activation is not None: + # compute final activation + out = self.final_activation(x) + # return both probabilities and logits + return out, x - return x + return x, x class UNet3D(AbstractUNet): @@ -247,3 +268,9 @@ def get_model(model_config): 'pytorch3dunet.unet3d.model' ]) return model_class(**model_config) + + +def is_model_2d(model): + if isinstance(model, nn.DataParallel): + model = model.module + return isinstance(model, UNet2D) diff --git a/pytorch3dunet/unet3d/predictor.py b/pytorch3dunet/unet3d/predictor.py index c9b4f6eb..d2cb2025 100644 --- a/pytorch3dunet/unet3d/predictor.py +++ b/pytorch3dunet/unet3d/predictor.py @@ -12,24 +12,32 @@ from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset from pytorch3dunet.datasets.utils import SliceBuilder, remove_padding -from pytorch3dunet.unet3d.model import UNet2D from pytorch3dunet.unet3d.utils import get_logger +from pytorch3dunet.unet3d.model import is_model_2d logger = get_logger('UNetPredictor') -def _get_output_file(dataset, suffix='_predictions', output_dir=None): - input_dir, file_name = os.path.split(dataset.file_path) +def _get_output_file(dataset: AbstractHDF5Dataset, suffix: str = '_predictions', output_dir: str = None) -> Path: + """ + Get the output file path for the predictions. + Args: + dataset: input dataset + suffix: file name suffix + output_dir: directory where the output file will be saved + + Returns: + path to the output file + """ + file_path = Path(dataset.file_path) + input_dir = file_path.parent if output_dir is None: output_dir = input_dir - output_filename = os.path.splitext(file_name)[0] + suffix + '.h5' - return Path(output_dir) / output_filename - + else: + output_dir = Path(output_dir) -def _is_2d_model(model): - if isinstance(model, nn.DataParallel): - model = model.module - return isinstance(model, UNet2D) + output_filename = file_path.stem + suffix + '.h5' + return Path(output_dir) / output_filename class _AbstractPredictor: @@ -89,19 +97,24 @@ def __call__(self, test_loader): logger.info(f'Running inference on {len(test_loader)} batches') # dimensionality of the output predictions - volume_shape = test_loader.dataset.volume_shape() - if self.prediction_channel is not None: - # single channel prediction map - prediction_maps_shape = (1,) + volume_shape + volume_shape = test_loader.dataset.volume_shape + + if self.save_segmentation: + # single channel segmentation map + prediction_shape = volume_shape else: - prediction_maps_shape = (self.out_channels,) + volume_shape + if self.prediction_channel is not None: + # single channel prediction map + prediction_shape = (1,) + volume_shape + else: + prediction_shape = (self.out_channels,) + volume_shape # create destination H5 file output_file = _get_output_file(dataset=test_loader.dataset, output_dir=self.output_dir) with h5py.File(output_file, 'w') as h5_output_file: # allocate prediction and normalization arrays logger.info('Allocating prediction and normalization arrays...') - prediction_map, normalization_mask = self._allocate_prediction_maps(prediction_maps_shape, h5_output_file) + prediction_array = self._allocate_prediction_array(prediction_shape, h5_output_file) # determine halo used for padding patch_halo = test_loader.dataset.halo_shape @@ -116,7 +129,7 @@ def __call__(self, test_loader): if torch.cuda.is_available(): input = input.pin_memory().cuda(non_blocking=True) - if _is_2d_model(self.model): + if is_model_2d(self.model): # remove the singleton z-dimension from the input input = torch.squeeze(input, dim=-3) # forward pass @@ -133,39 +146,43 @@ def __call__(self, test_loader): prediction = prediction.cpu().numpy() # for each batch sample for pred, index in zip(prediction, indices): - # save patch index: (C,D,H,W) - if self.prediction_channel is None: - channel_slice = slice(0, self.out_channels) + + if self.save_segmentation: + # if single channel, binarize + if pred.shape[0] == 1: + pred = pred[0] > 0.5 + else: + # use the argmax of the prediction + pred = np.argmax(pred, axis=0) + pred = pred.astype('uint16') + index = tuple(index) else: - # use only the specified channel - channel_slice = slice(0, 1) - pred = np.expand_dims(pred[self.prediction_channel], axis=0) + # save patch index: (C,D,H,W) + if self.prediction_channel is None: + channel_slice = slice(0, self.out_channels) + else: + # use only the specified channel + channel_slice = slice(0, 1) + pred = np.expand_dims(pred[self.prediction_channel], axis=0) + # add channel dimension to the index + index = (channel_slice,) + tuple(index) - # add channel dimension to the index - index = (channel_slice,) + tuple(index) # accumulate probabilities into the output prediction array - prediction_map[index] += pred - # count voxel visits for normalization - normalization_mask[index] += 1 + prediction_array[index] = pred logger.info(f'Finished inference in {time.perf_counter() - start:.2f} seconds') # save results output_type = 'segmentation' if self.save_segmentation else 'probability maps' logger.info(f'Saving {output_type} to: {output_file}') - self._save_results(prediction_map, normalization_mask, h5_output_file, test_loader.dataset) - - def _allocate_prediction_maps(self, output_shape, output_file): - # initialize the output prediction arrays - prediction_map = np.zeros(output_shape, dtype='float32') - # initialize normalization mask in order to average out probabilities of overlapping patches - normalization_mask = np.zeros(output_shape, dtype='uint8') - return prediction_map, normalization_mask + h5_output_file.create_dataset(self.output_dataset, data=prediction_array, compression="gzip") - def _save_results(self, prediction_map, normalization_mask, output_file, dataset): - result = prediction_map / normalization_mask + def _allocate_prediction_array(self, output_shape, output_file): if self.save_segmentation: - result = np.argmax(result, axis=0).astype('uint16') - output_file.create_dataset(self.output_dataset, data=result, compression="gzip") + dtype = 'uint16' + else: + dtype = 'float32' + # initialize the output prediction arrays + return np.zeros(output_shape, dtype=dtype) class LazyPredictor(StandardPredictor): @@ -186,7 +203,7 @@ def __init__(self, super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel, **kwargs) - def _allocate_prediction_maps(self, output_shape, output_file): + def _allocate_prediction_array(self, output_shape, output_file): # allocate datasets for probability maps prediction_map = output_file.create_dataset(self.output_dataset, shape=output_shape, @@ -201,22 +218,22 @@ def _allocate_prediction_maps(self, output_shape, output_file): compression='gzip') return prediction_map, normalization_mask - def _save_results(self, prediction_map, normalization_mask, output_file, dataset): - z, y, x = prediction_map.shape[1:] + def _save_results(self, prediction_array, normalization_mask, output_file, dataset): + z, y, x = prediction_array.shape[1:] # take slices which are 1/27 of the original volume patch_shape = (z // 3, y // 3, x // 3) if self.save_segmentation: output_file.create_dataset('segmentation', shape=(z, y, x), dtype='uint16', chunks=True, compression='gzip') - for index in SliceBuilder._build_slices(prediction_map, patch_shape=patch_shape, stride_shape=patch_shape): + for index in SliceBuilder._build_slices(prediction_array, patch_shape=patch_shape, stride_shape=patch_shape): logger.info(f'Normalizing slice: {index}') - prediction_map[index] /= normalization_mask[index] + prediction_array[index] /= normalization_mask[index] # make sure to reset the slice that has been visited already in order to avoid 'double' normalization # when the patches overlap with each other normalization_mask[index] = 1 # save segmentation if self.save_segmentation: - output_file['segmentation'][index[1:]] = np.argmax(prediction_map[index], axis=0).astype('uint16') + output_file['segmentation'][index[1:]] = np.argmax(prediction_array[index], axis=0).astype('uint16') del output_file['normalization'] if self.save_segmentation: diff --git a/pytorch3dunet/unet3d/trainer.py b/pytorch3dunet/unet3d/trainer.py index e3a12062..8c944c92 100644 --- a/pytorch3dunet/unet3d/trainer.py +++ b/pytorch3dunet/unet3d/trainer.py @@ -1,14 +1,18 @@ import os +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime + +import numpy as np import torch import torch.nn as nn from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.utils.tensorboard import SummaryWriter -from datetime import datetime +from tqdm import tqdm from pytorch3dunet.datasets.utils import get_train_loaders from pytorch3dunet.unet3d.losses import get_loss_criterion from pytorch3dunet.unet3d.metrics import get_evaluation_metric -from pytorch3dunet.unet3d.model import get_model, UNet2D +from pytorch3dunet.unet3d.model import get_model, is_model_2d from pytorch3dunet.unet3d.utils import get_logger, get_tensorboard_formatter, create_optimizer, \ create_lr_scheduler, get_number_of_learnable_parameters from . import utils @@ -82,15 +86,19 @@ class UNetTrainer: num_epoch (int): useful when loading the model from the checkpoint tensorboard_formatter (callable): converts a given batch of input/output/target image to a series of images that can be displayed in tensorboard - skip_train_validation (bool): if True eval_criterion is not evaluated on the training set (used mostly when + skip_train_validation (bool): if True eval_criterion is not evaluated on the training set (used when evaluation is expensive) + resume (string): path to the checkpoint to be resumed + pre_trained (string): path to the pre-trained model + max_val_images (int): maximum number of images to log during validation """ def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, checkpoint_dir, max_num_epochs, max_num_iterations, validate_after_iters=200, log_after_iters=100, validate_iters=None, num_iterations=1, num_epoch=0, eval_score_higher_is_better=True, tensorboard_formatter=None, - skip_train_validation=False, resume=None, pre_trained=None, **kwargs): + skip_train_validation=False, resume=None, pre_trained=None, max_val_images=100, **kwargs): + self.max_val_images = max_val_images self.model = model self.optimizer = optimizer self.scheduler = lr_scheduler @@ -116,10 +124,10 @@ def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterio self.writer = SummaryWriter( log_dir=os.path.join( - checkpoint_dir, 'logs', + checkpoint_dir, 'logs', datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - ) ) + ) assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided' self.tensorboard_formatter = tensorboard_formatter @@ -209,24 +217,12 @@ def train(self): if self.num_iterations % self.log_after_iters == 0: # compute eval criterion if not self.skip_train_validation: - # apply final activation before calculating eval score - if isinstance(self.model, nn.DataParallel): - final_activation = self.model.module.final_activation - else: - final_activation = self.model.final_activation - - if final_activation is not None: - act_output = final_activation(output) - else: - act_output = output - eval_score = self.eval_criterion(act_output, target) + eval_score = self.eval_criterion(output, target) train_eval_scores.update(eval_score.item(), self._batch_size(input)) - # log stats, params and images logger.info( f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}') self._log_stats('train', train_losses.avg, train_eval_scores.avg) - # self._log_params() self._log_images(input, target, output, 'train_') if self.should_stop(): @@ -260,16 +256,23 @@ def validate(self): val_scores = utils.RunningAverage() with torch.no_grad(): - for i, t in enumerate(self.loaders['val']): - logger.info(f'Validation iteration {i}') + # select indices of validation samples to log + rs = np.random.RandomState(42) + if len(self.loaders['val']) <= self.max_val_images: + indices = list(range(len(self.loaders['val']))) + else: + indices = rs.choice(len(self.loaders['val']), size=self.max_val_images, replace=False) + images_for_logging = [] + for i, t in enumerate(tqdm(self.loaders['val'])): input, target, weight = self._split_training_batch(t) output, loss = self._forward_pass(input, target, weight) val_losses.update(loss.item(), self._batch_size(input)) - if i % 100 == 0: - self._log_images(input, target, output, 'val_') + # save val images for logging + if i in indices: + images_for_logging.append((input, target, output, i)) eval_score = self.eval_criterion(output, target) val_scores.update(eval_score.item(), self._batch_size(input)) @@ -278,8 +281,13 @@ def validate(self): # stop validation break - self._log_stats('val', val_losses.avg, val_scores.avg) + # log images in a separate thread + with ThreadPoolExecutor() as executor: + for input, target, output, i in images_for_logging: + executor.submit(self._log_images, input, target, output, f'val_{i}_') + logger.info(f'Validation finished. Loss: {val_losses.avg}. Evaluation score: {val_scores.avg}') + self._log_stats('val', val_losses.avg, val_scores.avg) return val_scores.avg def _split_training_batch(self, t): @@ -299,24 +307,26 @@ def _move_to_gpu(input): input, target, weight = t return input, target, weight - def _forward_pass(self, input, target, weight=None): - if isinstance(self.model, UNet2D): + def _forward_pass(self, x, y, weight=None): + if is_model_2d(self.model): # remove the singleton z-dimension from the input - input = torch.squeeze(input, dim=-3) + x = torch.squeeze(x, dim=-3) # forward pass - output = self.model(input) + output, logits = self.model(x, return_logits=True) # add the singleton z-dimension to the output output = torch.unsqueeze(output, dim=-3) + logits = torch.unsqueeze(logits, dim=-3) else: # forward pass - output = self.model(input) + output, logits = self.model(x, return_logits=True) - # compute the loss + # always compute the loss using logits if weight is None: - loss = self.loss_criterion(output, target) + loss = self.loss_criterion(logits, y) else: - loss = self.loss_criterion(output, target, weight) + loss = self.loss_criterion(logits, y, weight) + # return probabilities and loss return output, loss def _is_best_eval_score(self, eval_score): @@ -369,16 +379,7 @@ def _log_params(self): self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations) self.writer.add_histogram(name + '/grad', value.grad.data.cpu().numpy(), self.num_iterations) - def _log_images(self, input, target, prediction, prefix=''): - - if isinstance(self.model, nn.DataParallel): - net = self.model.module - else: - net = self.model - - if net.final_activation is not None: - prediction = net.final_activation(prediction) - + def _log_images(self, input, target, prediction, prefix): inputs_map = { 'inputs': input, 'targets': target, diff --git a/pytorch3dunet/unet3d/utils.py b/pytorch3dunet/unet3d/utils.py index 01d5559c..d6b4609a 100644 --- a/pytorch3dunet/unet3d/utils.py +++ b/pytorch3dunet/unet3d/utils.py @@ -7,6 +7,7 @@ import h5py import numpy as np import torch +from skimage.color import label2rgb from torch import optim @@ -110,15 +111,16 @@ def number_of_features_per_level(init_channel_number, num_levels): return [init_channel_number * 2 ** k for k in range(num_levels)] -class _TensorboardFormatter: +class TensorboardFormatter: """ Tensorboard formatters converts a given batch of images (be it input/output to the network or the target segmentation image) to a series of images that can be displayed in tensorboard. This is the parent class for all tensorboard formatters which ensures that returned images are in the 'CHW' format. """ - def __init__(self, **kwargs): - pass + def __init__(self, skip_last_target=False, log_channelwise=False): + self.skip_last_target = skip_last_target + self.log_channelwise = log_channelwise def __call__(self, name, batch): """ @@ -128,6 +130,9 @@ def __call__(self, name, batch): Args: name (str): one of 'inputs'/'targets'/'predictions' batch (torch.tensor): 4D or 5D torch tensor + + Returns: + list[(str, np.ndarray)]: list of tuples of the form (tag, img) """ def _check_img(tag_img): @@ -143,24 +148,15 @@ def _check_img(tag_img): return tag, img - tagged_images = self.process_batch(name, batch) + tagged_images = self._process_batch(name, batch) return list(map(_check_img, tagged_images)) - def process_batch(self, name, batch): - raise NotImplementedError - - -class DefaultTensorboardFormatter(_TensorboardFormatter): - def __init__(self, skip_last_target=False, **kwargs): - super().__init__(**kwargs) - self.skip_last_target = skip_last_target - - def process_batch(self, name, batch): + def _process_batch(self, name, batch): if name == 'targets' and self.skip_last_target: batch = batch[:, :-1, ...] - tag_template = '{}/batch_{}/channel_{}/slice_{}' + tag_template = '{}/batch_{}/slice_{}' tagged_images = [] @@ -168,17 +164,52 @@ def process_batch(self, name, batch): # NCDHW slice_idx = batch.shape[2] // 2 # get the middle slice for batch_idx in range(batch.shape[0]): - for channel_idx in range(batch.shape[1]): - tag = tag_template.format(name, batch_idx, channel_idx, slice_idx) - img = batch[batch_idx, channel_idx, slice_idx, ...] - tagged_images.append((tag, self._normalize_img(img))) + if self.log_channelwise and name == 'predictions': + tag_template = '{}/batch_{}/channel_{}/slice_{}' + for channel_idx in range(batch.shape[1]): + tag = tag_template.format(name, batch_idx, channel_idx, slice_idx) + img = batch[batch_idx, channel_idx, slice_idx, ...] + tagged_images.append((tag, self._normalize_img(img))) + else: + tag = tag_template.format(name, batch_idx, slice_idx) + if name in ['predictions', 'targets']: + # for single channel predictions, just log the image + if batch.shape[1] == 1: + img = batch[batch_idx, :, slice_idx, ...] + tagged_images.append((tag, self._normalize_img(img))) + else: + # predictions are probabilities so convert to label image + img = batch[batch_idx].argmax(axis=0) + # take the middle slice + img = img[slice_idx, ...] + # convert to label image + img = label2rgb(img) + img = img.transpose(2, 0, 1) + tagged_images.append((tag, img)) + else: + # handle input images + if batch.shape[1] in [1, 3]: + # if single channel or RGB image, log directly + img = batch[batch_idx, :, slice_idx, ...] + tagged_images.append((tag, self._normalize_img(img))) + else: + # log channelwise + tag_template = '{}/batch_{}/channel_{}/slice_{}' + for channel_idx in range(batch.shape[1]): + tag = tag_template.format(name, batch_idx, channel_idx, slice_idx) + img = batch[batch_idx, channel_idx, slice_idx, ...] + tagged_images.append((tag, self._normalize_img(img))) + else: # batch has no channel dim: NDHW slice_idx = batch.shape[1] // 2 # get the middle slice for batch_idx in range(batch.shape[0]): - tag = tag_template.format(name, batch_idx, 0, slice_idx) + tag = tag_template.format(name, batch_idx, slice_idx) img = batch[batch_idx, slice_idx, ...] - tagged_images.append((tag, self._normalize_img(img))) + # this is target segmentation so convert to label image + lbl = label2rgb(img) + lbl = lbl.transpose(2, 0, 1) + tagged_images.append((tag, lbl)) return tagged_images @@ -211,12 +242,8 @@ def _find_masks(batch, min_size=10): def get_tensorboard_formatter(formatter_config): if formatter_config is None: - return DefaultTensorboardFormatter() - - class_name = formatter_config['name'] - m = importlib.import_module('pytorch3dunet.unet3d.utils') - clazz = getattr(m, class_name) - return clazz(**formatter_config) + return TensorboardFormatter() + return TensorboardFormatter(**formatter_config) def expand_as_one_hot(input, C, ignore_index=None): diff --git a/resources/3DUnet_confocal_boundary/train_config.yml b/resources/3DUnet_confocal_boundary/train_config.yml index 42a4111e..5ef8574f 100644 --- a/resources/3DUnet_confocal_boundary/train_config.yml +++ b/resources/3DUnet_confocal_boundary/train_config.yml @@ -65,11 +65,15 @@ trainer: # Configure training and validation loaders loaders: # how many subprocesses to use for data loading - num_workers: 8 + num_workers: 32 # path to the raw data within the H5 - raw_internal_path: /raw - # path to the the label data withtin the H5 - label_internal_path: /label + raw_internal_path: raw + # path to the label data within the H5 + label_internal_path: label + # apply random shifting and scaling of the patches; value of 20 mean that patches may shrink/stretch by 20px in each dimension + random_scale: 20 + # random scale execution probability; since random scale is quite slow for 3D data, we set it to 0.1 + random_scale_probability: 0.1 # configuration of the train loader train: # path to the training datasets diff --git a/resources/3DUnet_lightsheet_boundary/train_config.yml b/resources/3DUnet_lightsheet_boundary/train_config.yml index f5ac2945..58495d3a 100644 --- a/resources/3DUnet_lightsheet_boundary/train_config.yml +++ b/resources/3DUnet_lightsheet_boundary/train_config.yml @@ -66,11 +66,15 @@ trainer: # Configure training and validation loaders loaders: # how many subprocesses to use for data loading - num_workers: 8 + num_workers: 32 # path to the raw data within the H5 - raw_internal_path: /raw - # path to the the label data withtin the H5 - label_internal_path: /label + raw_internal_path: raw + # path to the label data within the H5 + label_internal_path: label + # apply random shifting and scaling of the patches; value of 20 mean that patches may shrink/stretch by 20px in each dimension + random_scale: 20 + # random scale execution probability; since random scale is quite slow for 3D data, we set it to 0.1 + random_scale_probability: 0.1 # configuration of the train loader train: # path to the training datasets diff --git a/resources/3DUnet_lightsheet_nuclei/train_config.yaml b/resources/3DUnet_lightsheet_nuclei/train_config.yaml index 6a2c25a8..43207a6d 100644 --- a/resources/3DUnet_lightsheet_nuclei/train_config.yaml +++ b/resources/3DUnet_lightsheet_nuclei/train_config.yaml @@ -83,8 +83,12 @@ loaders: num_workers: 8 # path to the raw data within the H5 raw_internal_path: raw - # path to the the label data within the H5 + # path to the label data within the H5 label_internal_path: label + # apply random shifting and scaling of the patches; value of 20 mean that patches may shrink/stretch by 20px in each dimension + random_scale: 20 + # random scale execution probability; since random scale is quite slow for 3D data, we set it to 0.1 + random_scale_probability: 0.1 # path to the pixel-wise weight map withing the H5 if present weight_internal_path: null # configuration of the train loader diff --git a/resources/3DUnet_multiclass/train_config.yaml b/resources/3DUnet_multiclass/train_config.yaml index 9b1dc31f..86a3efae 100644 --- a/resources/3DUnet_multiclass/train_config.yaml +++ b/resources/3DUnet_multiclass/train_config.yaml @@ -74,6 +74,10 @@ loaders: raw_internal_path: raw # path to the label data within the H5 label_internal_path: label + # apply random shifting and scaling of the patches; value of 20 mean that patches may shrink/stretch by 20px in each dimension + random_scale: 20 + # random scale execution probability; since random scale is quite slow for 3D data, we set it to 0.1 + random_scale_probability: 0.1 # path to the pixel-wise weight map withing the H5 if present weight_internal_path: null # configuration of the train loader diff --git a/tests/resources/transformer_config.yml b/tests/resources/transformer_config.yml index f48c6e6d..1025a798 100644 --- a/tests/resources/transformer_config.yml +++ b/tests/resources/transformer_config.yml @@ -2,26 +2,16 @@ train: transformer: raw: - - name: Standardize - name: RandomFlip - name: RandomRotate90 - - name: RandomRotate - axes: [[2, 1]] - angle_spectrum: 30 - mode: reflect - name: ElasticDeformation execution_probability: 1.0 spline_order: 0 - name: ToTensor expand_dims: true label: - - name: Standardize - name: RandomFlip - name: RandomRotate90 - - name: RandomRotate - axes: [[2, 1]] - angle_spectrum: 30 - mode: reflect - name: ElasticDeformation execution_probability: 1.0 spline_order: 0 diff --git a/tests/test_criterion.py b/tests/test_criterion.py index 40a4b1d9..6036da14 100644 --- a/tests/test_criterion.py +++ b/tests/test_criterion.py @@ -51,28 +51,22 @@ def test_mean_iou_simple(self): assert np.all(results > 0) assert np.all(results < 1) - def test_mean_iou(self): + def test_mean_iou_multi_channel(self): criterion = MeanIoU() - x = torch.randn(3, 3, 3, 3) - _, index = torch.max(x, dim=0, keepdim=True) - # create target tensor - target = torch.zeros_like(x, dtype=torch.long).scatter_(0, index, 1) - pred = torch.zeros_like(target, dtype=torch.float) - mask = target == 1 - # create prediction tensor - pred[mask] = torch.rand(1) - # make sure the dimensions are right - target = torch.unsqueeze(target, dim=0) - pred = torch.unsqueeze(pred, dim=0) - assert criterion(pred, target) == 1 - - def test_mean_iou_one_channel(self): - criterion = MeanIoU() - pred = torch.rand(1, 1, 3, 3, 3) + pred = torch.rand(10, 3, 10, 10, 10) target = pred > 0.5 target = target.long() assert criterion(pred, target) == 1 + def test_mean_iou_multi_class(self): + criterion = MeanIoU() + n_classes = 5 + n_batch = 10 + pred = torch.rand(n_batch, n_classes, 10, 10, 10) + target = torch.randint(0, n_classes, (n_batch, 10, 10, 10)) + mean_iou = criterion(pred, target) + assert mean_iou >= 0 + def test_average_precision_synthethic_data(self): input = np.zeros((64, 200, 200), dtype=np.int32) for i in range(40, 200, 40): @@ -129,7 +123,7 @@ def test_dice_loss(self): assert np.all(results < 1) def test_bce_dice_loss(self): - results = _compute_criterion(BCEDiceLoss(1., 1.)) + results = _compute_criterion(BCEDiceLoss(1.)) results = np.array(results) assert np.all(results > 0) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 14ffbd73..e0107771 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -144,6 +144,30 @@ def test_halo(self): input_ = remove_padding(input_, halo_shape) assert np.allclose(input_[0], raw[indices]) + def test_random_scale(self, transformer_config): + path = create_random_dataset((200, 200, 172)) + + patch_shapes = [(172, 172, 172)] + stride_shapes = [(28, 28, 28)] + + phase = 'train' + + for patch_shape, stride_shape in zip(patch_shapes, stride_shapes): + dataset = StandardHDF5Dataset(path, phase=phase, + slice_builder_config=_slice_builder_conf(patch_shape, stride_shape), + transformer_config=transformer_config[phase]['transformer'], + raw_internal_path='raw', + label_internal_path='label', + random_scale=20) + + for raw, label in dataset: + if raw.ndim == 3: + assert raw.shape == patch_shape + assert label.shape == patch_shape + else: + assert raw.shape[1:] == patch_shape + assert label.shape[1:] == patch_shape + def create_random_dataset(shape, ignore_index=False, raw_datasets=None, label_datasets=None): if label_datasets is None: diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 5fbce5f0..8f596498 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -10,7 +10,7 @@ from pytorch3dunet.unet3d.metrics import get_evaluation_metric from pytorch3dunet.unet3d.model import get_model from pytorch3dunet.unet3d.trainer import UNetTrainer -from pytorch3dunet.unet3d.utils import DefaultTensorboardFormatter, create_optimizer, create_lr_scheduler +from pytorch3dunet.unet3d.utils import TensorboardFormatter, create_optimizer, create_lr_scheduler class TestUNet3DTrainer: @@ -97,7 +97,7 @@ def _train_save_load(tmpdir, train_config, loss, val_metric, model, weight_map, optimizer = create_optimizer(train_config['optimizer'], model) lr_scheduler = create_lr_scheduler(train_config.get('lr_scheduler', None), optimizer) - formatter = DefaultTensorboardFormatter() + formatter = TensorboardFormatter() trainer = UNetTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, tmpdir, max_num_epochs=train_config['trainer']['max_num_epochs'], max_num_iterations=train_config['trainer']['max_num_iterations'], diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 68dcce3f..cc182705 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,7 +1,7 @@ import numpy as np from pytorch3dunet.augment.transforms import RandomLabelToAffinities, LabelToAffinities, Transformer, Relabel, \ - CropToFixed + CropToFixed, RandomGammaCorrection class TestTransforms: @@ -214,6 +214,12 @@ def test_crop_to_fixed(self): assert np.array_equal(m_crop, t(m)) + def test_random_gamma_correction(self): + m = np.random.rand(200, 200, 200) + t = RandomGammaCorrection(np.random.RandomState(), gamma=(1.1, 2.0), execution_probability=1.0) + # Output is darker for gamma > 1 + assert np.mean(m) > np.mean(t(m)) + def _diagonal_label_volume(size, init=1): label = init * np.ones((size, size, size), dtype=np.int32)