Skip to content

Commit

Permalink
Docstrings for U-Net
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>

decorators for U-Net modules

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>

ignore this block

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>

removing decorators for UNet.forward for now

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
João Lucas de Sousa Almeida authored and João Lucas de Sousa Almeida committed Oct 25, 2023
1 parent 181d133 commit a33fe17
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 12 deletions.
137 changes: 126 additions & 11 deletions simulai/models/_pytorch_models/_unet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import numpy as np
import torch
from typing import Union, List, Tuple, Optional
from typing import Union, List, Tuple, Optional, Dict

from simulai.templates import NetworkTemplate, as_tensor, channels_dim
from simulai.regression import DenseNetwork, SLFNN, ConvolutionalNetwork
Expand All @@ -14,8 +14,8 @@ class CNNUnetEncoder(ConvolutionalNetwork):

def __init__(
self,
layers: list = None,
activations: list = None,
layers: List[Dict] = None,
activations: Union[str, List[str]] = None,
pre_layer: Optional[torch.nn.Module] = None,
case: str = "2d",
last_activation: str = "identity",
Expand All @@ -24,6 +24,32 @@ def __init__(
intermediary_outputs_indices: List[int] = None,
name: str = None,
) -> None:
"""
A CNN encoder for U-Nets.
Parameters
----------
layers : List[Dict]
A list of configurations dictionaries for instantiating the layers.
activations :
A string or a list of strings defining the kind of activation to be used.
pre_layer : Optional[torch.nn.Module]
A layer for pre-processing the input.
case : str
The kind of CNN to be used, in ["1d", "2d", "3d"].
last_activation : str
The kind of activation to be used after the last layer.
transpose : bool
Using transposed convolution or not.
flatten : bool
Flattening the output or not.
intermediary_outputs_indices : List[int],
A list of indices for indicating what are the encoder outputs, which will be
subsequently inputted in the decoder stage.
name : str
A name for the model.
"""

super(CNNUnetEncoder, self).__init__(layers=layers,
activations=activations,
Expand All @@ -43,7 +69,22 @@ def __init__(
@channels_dim
def forward(
self, input_data: Union[torch.Tensor, np.ndarray] = None
) -> torch.Tensor:
) -> [torch.Tensor, List[torch.Tensor]]:
"""
The CNN U-Net encoder forward method.
Parameters
----------
input_data : Union[torch.Tensor, np.ndarray],
A dataset to be inputted in the CNN U-Net encoder.
Returns
-------
[torch.Tensor, List[torch.Tensor]]
A list containing the main encoder output (latent space) and
another list of outputs, corresponding to the intermediary encoder
outputs.
"""

intermediary_outputs = list()

Expand Down Expand Up @@ -71,6 +112,32 @@ def __init__(
name: str = None,
channels_last=False,
) -> None:
"""
A CNN decoder for U-Nets.
Parameters
----------
layers : List[Dict]
A list of configurations dictionaries for instantiating the layers.
activations :
A string or a list of strings defining the kind of activation to be used.
pre_layer : Optional[torch.nn.Module]
A layer for pre-processing the input.
case : str
The kind of CNN to be used, in ["1d", "2d", "3d"].
last_activation : str
The kind of activation to be used after the last layer.
transpose : bool
Using transposed convolution or not.
flatten : bool
Flattening the output or not.
intermediary_inputs_indices : List[int],
A list of indices for indicating what are the decoder outputs.
name : str
A name for the model.
"""


super(CNNUnetDecoder, self).__init__(layers=layers,
activations=activations,
Expand All @@ -93,12 +160,26 @@ def __init__(
if not isinstance(layer_j, torch.nn.Identity)]
self.pipeline = torch.nn.Sequential(*self.list_of_layers)

#@as_tensor
#@channels_dim
def forward(
self, input_data: Union[torch.Tensor, np.ndarray] = None,
intermediary_encoder_outputs:List[torch.Tensor] = None,
) -> torch.Tensor:
"""
The CNN U-Net decoder forward method.
Parameters
----------
input_data : Union[torch.Tensor, np.ndarray],
A dataset to be inputted in the CNN U-Net decoder.
intermediary_encoder_outputs : List[torch.Tensor]
A list of tensors, corresponding to the intermediary encoder outputs.
Returns
-------
torch.Tensor
The decoder (and U-Net) output.
"""

current_input = input_data
intermediary_encoder_outputs = intermediary_encoder_outputs[::-1]
Expand All @@ -120,11 +201,29 @@ def forward(
class UNet(NetworkTemplate):


def __init__(self, layers_config:dict=None,
intermediary_outputs_indices:List=None,
intermediary_inputs_indices:List=None,
encoder_extra_args:dict=dict(),
decoder_extra_args:dict=dict()) -> None:
def __init__(self, layers_config:Dict=None,
intermediary_outputs_indices:List[int]=None,
intermediary_inputs_indices:List[int]=None,
encoder_extra_args:Dict=dict(),
decoder_extra_args:Dict=dict()) -> None:
"""
U-Net.
Parameters
----------
layers_config : Dict
A dictionary containing the complete configuration for the
U-Net encoder and decoder.
intermediary_outputs_indices : List[int]
A list of indices for indicating the encoder outputs.
intermediary_inputs_indices : List[int]
A list of indices for indicating the decoder inputs.
encoder_extra_args : Dict
A dictionary containing extra arguments for the encoder.
decoder_extra_args : Dict
A dictionary containing extra arguments for the decoder.
"""

super(UNet, self).__init__()

Expand Down Expand Up @@ -172,6 +271,19 @@ def __init__(self, layers_config:dict=None,
@as_tensor
def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None
) -> torch.Tensor:
"""
The U-Net forward method.
Parameters
----------
input_data : Union[torch.Tensor, np.ndarray],
A dataset to be inputted in the CNN U-Net encoder.
Returns
-------
torch.Tensor
The U-Net output.
"""

encoder_main_output, encoder_intermediary_outputs = self.encoder(input_data=input_data)
output = self.decoder(input_data = encoder_main_output,
Expand All @@ -180,5 +292,8 @@ def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None
return output

def summary(self):
"""
It shows a general view of the architecture.
"""

print(self)
2 changes: 1 addition & 1 deletion simulai/templates/_pytorch_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _setup_activations(
for activation_name in activation:
activation_op = self._get_operation(operation=activation_name, is_activation=True)

activations_list.append(activation_op)##activation_op())
activations_list.append(activation_op)

return activations_list, activation

Expand Down

0 comments on commit a33fe17

Please sign in to comment.