From 941e739c933691a2da2d111d3a693f47d6330939 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Wed, 13 Nov 2024 10:44:17 -0500 Subject: [PATCH] Add MedNext implementation (#8004) Fixes #7786 ### Description Added MedNext architectures implementation for MONAI. Since a lot of the code is heavily sourced from the original MedNext repo, https://github.com/MIC-DKFZ/MedNeXt, I wanted to check if there is an attribution policy with regarded to borrowed source code. I've added a derivative notice bellow the monai copyright comment. Let me know if this needs to be changed. The blocks have been taken almost as is but the network implementation has been changed largely to allow flexible blocks and follow MONAI segresnet styling. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Suraj Pai Signed-off-by: Robin CREMESE Co-authored-by: Robin CREMESE Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/mednext_block.py | 309 +++++++++++++++++++++ monai/networks/nets/__init__.py | 19 ++ monai/networks/nets/mednext.py | 354 +++++++++++++++++++++++++ tests/test_mednext.py | 122 +++++++++ 5 files changed, 805 insertions(+) create mode 100644 monai/networks/blocks/mednext_block.py create mode 100644 monai/networks/nets/mednext.py create mode 100644 tests/test_mednext.py diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 47abc4a1c4..499caf2e0f 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -26,6 +26,7 @@ from .fcn import FCN, GCN, MCFCN, Refine from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7 from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock +from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock from .mlp import MLPBlock from .patchembedding import PatchEmbed, PatchEmbeddingBlock from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py new file mode 100644 index 0000000000..0aa2bb6b58 --- /dev/null +++ b/monai/networks/blocks/mednext_block.py @@ -0,0 +1,309 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Portions of this code are derived from the original repository at: +# https://github.com/MIC-DKFZ/MedNeXt +# and are used under the terms of the Apache License, Version 2.0. + +from __future__ import annotations + +import torch +import torch.nn as nn + +all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"] + + +def get_conv_layer(spatial_dim: int = 3, transpose: bool = False): + if spatial_dim == 2: + return nn.ConvTranspose2d if transpose else nn.Conv2d + else: # spatial_dim == 3 + return nn.ConvTranspose3d if transpose else nn.Conv3d + + +class MedNeXtBlock(nn.Module): + """ + MedNeXtBlock class for the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (int): Expansion ratio for the block. Defaults to 4. + kernel_size (int): Kernel size for convolutions. Defaults to 7. + use_residual_connection (int): Whether to use residual connection. Defaults to True. + norm_type (str): Type of normalization to use. Defaults to "group". + dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d". + global_resp_norm (bool): Whether to use global response normalization. Defaults to False. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion_ratio: int = 4, + kernel_size: int = 7, + use_residual_connection: int = True, + norm_type: str = "group", + dim="3d", + global_resp_norm=False, + ): + + super().__init__() + + self.do_res = use_residual_connection + + self.dim = dim + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) + global_resp_norm_param_shape = (1,) * (2 if dim == "2d" else 3) + # First convolution layer with DepthWise Convolutions + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=in_channels, + ) + + # Normalization Layer. GroupNorm is used by default. + if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore + elif norm_type == "layer": + self.norm = nn.LayerNorm( + normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore + ) + # Second convolution (Expansion) layer with Conv3D 1x1x1 + self.conv2 = conv( + in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0 + ) + + # GeLU activations + self.act = nn.GELU() + + # Third convolution (Compression) layer with Conv3D 1x1x1 + self.conv3 = conv( + in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + self.global_resp_norm = global_resp_norm + if self.global_resp_norm: + global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + global_resp_norm_param_shape + self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True) + self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True) + + def forward(self, x): + """ + Forward pass of the MedNeXtBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + x1 = x + x1 = self.conv1(x1) + x1 = self.act(self.conv2(self.norm(x1))) + + if self.global_resp_norm: + # gamma, beta: learnable affine transform parameters + # X: input of shape (N,C,H,W,D) + if self.dim == "2d": + gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True) + else: + gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6) + x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1 + x1 = self.conv3(x1) + if self.do_res: + x1 = x + x1 + return x1 + + +class MedNeXtDownBlock(MedNeXtBlock): + """ + MedNeXtDownBlock class for downsampling in the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (int): Expansion ratio for the block. Defaults to 4. + kernel_size (int): Kernel size for convolutions. Defaults to 7. + use_residual_connection (bool): Whether to use residual connection. Defaults to False. + norm_type (str): Type of normalization to use. Defaults to "group". + dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d". + global_resp_norm (bool): Whether to use global response normalization. Defaults to False. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion_ratio: int = 4, + kernel_size: int = 7, + use_residual_connection: bool = False, + norm_type: str = "group", + dim: str = "3d", + global_resp_norm: bool = False, + ): + + super().__init__( + in_channels, + out_channels, + expansion_ratio, + kernel_size, + use_residual_connection=False, + norm_type=norm_type, + dim=dim, + global_resp_norm=global_resp_norm, + ) + + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) + self.resample_do_res = use_residual_connection + if use_residual_connection: + self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) + + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + groups=in_channels, + ) + + def forward(self, x): + """ + Forward pass of the MedNeXtDownBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + x1 = super().forward(x) + + if self.resample_do_res: + res = self.res_conv(x) + x1 = x1 + res + + return x1 + + +class MedNeXtUpBlock(MedNeXtBlock): + """ + MedNeXtUpBlock class for upsampling in the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (int): Expansion ratio for the block. Defaults to 4. + kernel_size (int): Kernel size for convolutions. Defaults to 7. + use_residual_connection (bool): Whether to use residual connection. Defaults to False. + norm_type (str): Type of normalization to use. Defaults to "group". + dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d". + global_resp_norm (bool): Whether to use global response normalization. Defaults to False. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion_ratio: int = 4, + kernel_size: int = 7, + use_residual_connection: bool = False, + norm_type: str = "group", + dim: str = "3d", + global_resp_norm: bool = False, + ): + super().__init__( + in_channels, + out_channels, + expansion_ratio, + kernel_size, + use_residual_connection=False, + norm_type=norm_type, + dim=dim, + global_resp_norm=global_resp_norm, + ) + + self.resample_do_res = use_residual_connection + + self.dim = dim + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True) + if use_residual_connection: + self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) + + self.conv1 = conv( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + groups=in_channels, + ) + + def forward(self, x): + """ + Forward pass of the MedNeXtUpBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + x1 = super().forward(x) + # Asymmetry but necessary to match shape + + if self.dim == "2d": + x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0)) + else: + x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0)) + + if self.resample_do_res: + res = self.res_conv(x) + if self.dim == "2d": + res = torch.nn.functional.pad(res, (1, 0, 1, 0)) + else: + res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0)) + x1 = x1 + res + + return x1 + + +class MedNeXtOutBlock(nn.Module): + """ + MedNeXtOutBlock class for the output block in the MedNeXt model. + + Args: + in_channels (int): Number of input channels. + n_classes (int): Number of output classes. + dim (str): Dimension of the input. Can be "2d" or "3d". + """ + + def __init__(self, in_channels, n_classes, dim): + super().__init__() + + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True) + self.conv_out = conv(in_channels, n_classes, kernel_size=1) + + def forward(self, x): + """ + Forward pass of the MedNeXtOutBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + return self.conv_out(x) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 0570c9fcc1..b876e6a3fc 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -53,6 +53,25 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet +from .mednext import ( + MedNeXt, + MedNext, + MedNextB, + MedNeXtB, + MedNextBase, + MedNextL, + MedNeXtL, + MedNeXtLarge, + MedNextLarge, + MedNextM, + MedNeXtM, + MedNeXtMedium, + MedNextMedium, + MedNextS, + MedNeXtS, + MedNeXtSmall, + MedNextSmall, +) from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py new file mode 100644 index 0000000000..427572ba60 --- /dev/null +++ b/monai/networks/nets/mednext.py @@ -0,0 +1,354 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Portions of this code are derived from the original repository at: +# https://github.com/MIC-DKFZ/MedNeXt +# and are used under the terms of the Apache License, Version 2.0. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn as nn + +from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock + +__all__ = [ + "MedNeXt", + "MedNeXtSmall", + "MedNeXtBase", + "MedNeXtMedium", + "MedNeXtLarge", + "MedNext", + "MedNextS", + "MedNeXtS", + "MedNextSmall", + "MedNextB", + "MedNeXtB", + "MedNextBase", + "MedNextM", + "MedNeXtM", + "MedNextMedium", + "MedNextL", + "MedNeXtL", + "MedNextLarge", +] + + +class MedNeXt(nn.Module): + """ + MedNeXt model class from paper: https://arxiv.org/pdf/2303.09975 + + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. + in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + encoder_expansion_ratio: expansion ratio for encoder blocks. Defaults to 2. + decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2. + bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2. + kernel_size: kernel size for convolutions. Defaults to 7. + deep_supervision: whether to use deep supervision. Defaults to False. + use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False. + blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2]. + blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2. + blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2]. + norm_type: type of normalization layer. Defaults to 'group'. + global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808 + """ + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + out_channels: int = 2, + encoder_expansion_ratio: Sequence[int] | int = 2, + decoder_expansion_ratio: Sequence[int] | int = 2, + bottleneck_expansion_ratio: int = 2, + kernel_size: int = 7, + deep_supervision: bool = False, + use_residual_connection: bool = False, + blocks_down: Sequence[int] = (2, 2, 2, 2), + blocks_bottleneck: int = 2, + blocks_up: Sequence[int] = (2, 2, 2, 2), + norm_type: str = "group", + global_resp_norm: bool = False, + ): + """ + Initialize the MedNeXt model. + + This method sets up the architecture of the model, including: + - Stem convolution + - Encoder stages and downsampling blocks + - Bottleneck blocks + - Decoder stages and upsampling blocks + - Output blocks for deep supervision (if enabled) + """ + super().__init__() + + self.do_ds = deep_supervision + assert spatial_dims in [2, 3], "`spatial_dims` can only be 2 or 3." + spatial_dims_str = f"{spatial_dims}d" + enc_kernel_size = dec_kernel_size = kernel_size + + if isinstance(encoder_expansion_ratio, int): + encoder_expansion_ratio = [encoder_expansion_ratio] * len(blocks_down) + + if isinstance(decoder_expansion_ratio, int): + decoder_expansion_ratio = [decoder_expansion_ratio] * len(blocks_up) + + conv = nn.Conv2d if spatial_dims_str == "2d" else nn.Conv3d + + self.stem = conv(in_channels, init_filters, kernel_size=1) + + enc_stages = [] + down_blocks = [] + + for i, num_blocks in enumerate(blocks_down): + enc_stages.append( + nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2**i), + out_channels=init_filters * (2**i), + expansion_ratio=encoder_expansion_ratio[i], + kernel_size=enc_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + global_resp_norm=global_resp_norm, + ) + for _ in range(num_blocks) + ] + ) + ) + + down_blocks.append( + MedNeXtDownBlock( + in_channels=init_filters * (2**i), + out_channels=init_filters * (2 ** (i + 1)), + expansion_ratio=encoder_expansion_ratio[i], + kernel_size=enc_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + ) + ) + + self.enc_stages = nn.ModuleList(enc_stages) + self.down_blocks = nn.ModuleList(down_blocks) + + self.bottleneck = nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2 ** len(blocks_down)), + out_channels=init_filters * (2 ** len(blocks_down)), + expansion_ratio=bottleneck_expansion_ratio, + kernel_size=dec_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + global_resp_norm=global_resp_norm, + ) + for _ in range(blocks_bottleneck) + ] + ) + + up_blocks = [] + dec_stages = [] + for i, num_blocks in enumerate(blocks_up): + up_blocks.append( + MedNeXtUpBlock( + in_channels=init_filters * (2 ** (len(blocks_up) - i)), + out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + expansion_ratio=decoder_expansion_ratio[i], + kernel_size=dec_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + global_resp_norm=global_resp_norm, + ) + ) + + dec_stages.append( + nn.Sequential( + *[ + MedNeXtBlock( + in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)), + expansion_ratio=decoder_expansion_ratio[i], + kernel_size=dec_kernel_size, + use_residual_connection=use_residual_connection, + norm_type=norm_type, + dim=spatial_dims_str, + global_resp_norm=global_resp_norm, + ) + for _ in range(num_blocks) + ] + ) + ) + + self.up_blocks = nn.ModuleList(up_blocks) + self.dec_stages = nn.ModuleList(dec_stages) + + self.out_0 = MedNeXtOutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str) + + if deep_supervision: + out_blocks = [ + MedNeXtOutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str) + for i in range(1, len(blocks_up) + 1) + ] + + out_blocks.reverse() + self.out_blocks = nn.ModuleList(out_blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]: + """ + Forward pass of the MedNeXt model. + + This method performs the forward pass through the model, including: + - Stem convolution + - Encoder stages and downsampling + - Bottleneck blocks + - Decoder stages and upsampling with skip connections + - Output blocks for deep supervision (if enabled) + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor or Sequence[torch.Tensor]: Output tensor(s). + """ + # Apply stem convolution + x = self.stem(x) + + # Encoder forward pass + enc_outputs = [] + for enc_stage, down_block in zip(self.enc_stages, self.down_blocks): + x = enc_stage(x) + enc_outputs.append(x) + x = down_block(x) + + # Bottleneck forward pass + x = self.bottleneck(x) + + # Initialize deep supervision outputs if enabled + if self.do_ds: + ds_outputs = [] + + # Decoder forward pass with skip connections + for i, (up_block, dec_stage) in enumerate(zip(self.up_blocks, self.dec_stages)): + if self.do_ds and i < len(self.out_blocks): + ds_outputs.append(self.out_blocks[i](x)) + + x = up_block(x) + x = x + enc_outputs[-(i + 1)] + x = dec_stage(x) + + # Final output block + x = self.out_0(x) + + # Return output(s) + if self.do_ds and self.training: + return (x, *ds_outputs[::-1]) + else: + return x + + +# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975 +def create_mednext( + variant: str, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 2, + kernel_size: int = 3, + deep_supervision: bool = False, +) -> MedNeXt: + """ + Factory method to create MedNeXt variants. + + Args: + variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L'). + spatial_dims (int): Number of spatial dimensions. Defaults to 3. + in_channels (int): Number of input channels. Defaults to 1. + out_channels (int): Number of output channels. Defaults to 2. + kernel_size (int): Kernel size for convolutions. Defaults to 3. + deep_supervision (bool): Whether to use deep supervision. Defaults to False. + + Returns: + MedNeXt: The specified MedNeXt variant. + + Raises: + ValueError: If an invalid variant is specified. + """ + common_args = { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "deep_supervision": deep_supervision, + "use_residual_connection": True, + "norm_type": "group", + "global_resp_norm": False, + "init_filters": 32, + } + + if variant.upper() == "S": + return MedNeXt( + encoder_expansion_ratio=2, + decoder_expansion_ratio=2, + bottleneck_expansion_ratio=2, + blocks_down=(2, 2, 2, 2), + blocks_bottleneck=2, + blocks_up=(2, 2, 2, 2), + **common_args, # type: ignore + ) + elif variant.upper() == "B": + return MedNeXt( + encoder_expansion_ratio=(2, 3, 4, 4), + decoder_expansion_ratio=(4, 4, 3, 2), + bottleneck_expansion_ratio=4, + blocks_down=(2, 2, 2, 2), + blocks_bottleneck=2, + blocks_up=(2, 2, 2, 2), + **common_args, # type: ignore + ) + elif variant.upper() == "M": + return MedNeXt( + encoder_expansion_ratio=(2, 3, 4, 4), + decoder_expansion_ratio=(4, 4, 3, 2), + bottleneck_expansion_ratio=4, + blocks_down=(3, 4, 4, 4), + blocks_bottleneck=4, + blocks_up=(4, 4, 4, 3), + **common_args, # type: ignore + ) + elif variant.upper() == "L": + return MedNeXt( + encoder_expansion_ratio=(3, 4, 8, 8), + decoder_expansion_ratio=(8, 8, 4, 3), + bottleneck_expansion_ratio=8, + blocks_down=(3, 4, 8, 8), + blocks_bottleneck=8, + blocks_up=(8, 8, 4, 3), + **common_args, # type: ignore + ) + else: + raise ValueError(f"Invalid MedNeXt variant: {variant}") + + +MedNext = MedNeXt +MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = lambda **kwargs: create_mednext("S", **kwargs) +MedNextB = MedNeXtB = MedNextBase = MedNeXtBase = lambda **kwargs: create_mednext("B", **kwargs) +MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = lambda **kwargs: create_mednext("M", **kwargs) +MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = lambda **kwargs: create_mednext("L", **kwargs) diff --git a/tests/test_mednext.py b/tests/test_mednext.py new file mode 100644 index 0000000000..b4ba4f9939 --- /dev/null +++ b/tests/test_mednext.py @@ -0,0 +1,122 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASE_MEDNEXT = [] +for spatial_dims in range(2, 4): + for init_filters in [8, 16]: + for deep_supervision in [False, True]: + for do_res in [False, True]: + test_case = [ + { + "spatial_dims": spatial_dims, + "init_filters": init_filters, + "deep_supervision": deep_supervision, + "use_residual_connection": do_res, + }, + (2, 1, *([16] * spatial_dims)), + (2, 2, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT.append(test_case) + +TEST_CASE_MEDNEXT_2 = [] +for spatial_dims in range(2, 4): + for out_channels in [1, 2]: + for deep_supervision in [False, True]: + test_case = [ + { + "spatial_dims": spatial_dims, + "init_filters": 8, + "out_channels": out_channels, + "deep_supervision": deep_supervision, + }, + (2, 1, *([16] * spatial_dims)), + (2, out_channels, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT_2.append(test_case) + +TEST_CASE_MEDNEXT_VARIANTS = [] +for model in [MedNeXtS, MedNeXtM, MedNeXtL]: + for spatial_dims in range(2, 4): + for out_channels in [1, 2]: + test_case = [ + model, # type: ignore + {"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels}, + (2, 1, *([16] * spatial_dims)), + (2, out_channels, *([16] * spatial_dims)), + ] + TEST_CASE_MEDNEXT_VARIANTS.append(test_case) + + +class TestMedNeXt(unittest.TestCase): + + @parameterized.expand(TEST_CASE_MEDNEXT) + def test_shape(self, input_param, input_shape, expected_shape): + net = MedNeXt(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + if input_param["deep_supervision"] and net.training: + assert isinstance(result, tuple) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + else: + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + @parameterized.expand(TEST_CASE_MEDNEXT_2) + def test_shape2(self, input_param, input_shape, expected_shape): + net = MedNeXt(**input_param).to(device) + + net.train() + result = net(torch.randn(input_shape).to(device)) + if input_param["deep_supervision"]: + assert isinstance(result, tuple) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + else: + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + net.eval() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + MedNeXt(spatial_dims=4) + + @parameterized.expand(TEST_CASE_MEDNEXT_VARIANTS) + def test_mednext_variants(self, model, input_param, input_shape, expected_shape): + net = model(**input_param).to(device) + + net.train() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + net.eval() + with torch.no_grad(): + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + +if __name__ == "__main__": + unittest.main()