From 5c1d0f663c19f6d2d5d0af6869178167890c2810 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?=
Date: Thu, 26 Oct 2023 09:40:55 -0300
Subject: [PATCH] Docstrings for simulai.models.Transformer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: João Lucas de Sousa Almeida
---
 .../models/_pytorch_models/_transformer.py | 140 ++++++++++++++++--
 1 file changed, 130 insertions(+), 10 deletions(-)

diff --git a/simulai/models/_pytorch_models/_transformer.py b/simulai/models/_pytorch_models/_transformer.py
index 3aca7fec..890a67ca 100644
--- a/simulai/models/_pytorch_models/_transformer.py
+++ b/simulai/models/_pytorch_models/_transformer.py
@@ -9,10 +9,30 @@ class BaseTemplate(NetworkTemplate):
 
     def __init__(self):
+        """
+        Template used for sharing fundamental methods with the
+        children transformer-like encoders and decoders.
+        """
         super(BaseTemplate, self).__init__()
 
     def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.nn.Module:
+        """
+        It configures the activation functions for the transformer layers.
+
+        Parameters
+        ----------
+        activation : Union[str, torch.nn.Module]
+            Activation function to be used in all the network layers.
+        Returns
+        -------
+            A Module object for this activation function.
+
+        Raises
+        ------
+        Exception :
+            When the activation function is not supported.
+        """
         if isinstance(activation, torch.nn.Module):
             return encoder_activation
 
@@ -23,11 +43,25 @@ def _activation_getter(self, activation: Union[str, torch.nn.Module]) -> torch.n
 class BasicEncoder(BaseTemplate):
-    def __init__(self, num_heads=1,
+    def __init__(self, num_heads:int=1,
                  activation:Union[str, torch.nn.Module]='relu',
                  mlp_layer:torch.nn.Module=None,
                  embed_dim:Union[int, Tuple]=None,
-                 ):
+                 ) -> None:
+        """
+        Generic transformer encoder.
+
+        Parameters
+        ----------
+        num_heads : int
+            Number of attention heads for the self-attention layers.
+        activation : Union[str, torch.nn.Module]
+            Activation function to be used in all the network layers.
+        mlp_layer : torch.nn.Module
+            A Module object representing the MLP (Dense) operation.
+        embed_dim : Union[int, Tuple]
+            Dimension used for the transformer embedding.
+        """
 
         super(BasicEncoder, self).__init__()
 
@@ -53,6 +87,18 @@ def __init__(self, num_heads=1,
     def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None
                 ) -> torch.Tensor:
+        """
+
+        Parameters
+        ----------
+        input_data : Union[torch.Tensor, np.ndarray]
+            The input dataset.
+
+        Returns
+        -------
+        torch.Tensor
+            The output generated by the encoder.
+        """
 
         h = input_data
         h1 = self.activation_1(h)
@@ -68,6 +114,20 @@ def __init__(self, num_heads:int=1,
                  activation:Union[str, torch.nn.Module]='relu',
                  mlp_layer:torch.nn.Module=None,
                  embed_dim:Union[int, Tuple]=None):
+        """
+        Generic transformer decoder.
+
+        Parameters
+        ----------
+        num_heads : int
+            Number of attention heads for the self-attention layers.
+        activation : Union[str, torch.nn.Module]
+            Activation function to be used in all the network layers.
+        mlp_layer : torch.nn.Module
+            A Module object representing the MLP (Dense) operation.
+        embed_dim : Union[int, Tuple]
+            Dimension used for the transformer embedding.
+        """
 
         super(BasicDecoder, self).__init__()
 
@@ -94,6 +154,20 @@ def __init__(self, num_heads:int=1,
     def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None,
                 encoder_output:torch.Tensor=None,
                 ) -> torch.Tensor:
+        """
+
+        Parameters
+        ----------
+        input_data : Union[torch.Tensor, np.ndarray]
+            The input dataset (in principle, the same input used for the encoder).
+        encoder_output : torch.Tensor
+            The output provided by the encoder stage.
+
+        Returns
+        -------
+        torch.Tensor
+            The decoder output.
+ """ h = input_data h1 = self.activation_1(h) @@ -115,6 +189,37 @@ def __init__(self, num_heads_encoder:int=1, decoder_mlp_layer_config:dict=None, number_of_encoders:int=1, number_of_decoders:int=1) -> None: + """ + A classical encoder-decoder transformer: + + U -> ( Encoder_1 -> Encoder_2 -> ... -> Encoder_N ) -> u_e + + (u_e, U) -> ( Decoder_1 -> Decoder_2 -> ... Decoder_N ) -> V + + Parameters + ---------- + num_heads_encoder : int + The number of heads for the self-attention layer of the encoder. + num_heads_decoder :int + The number of heads for the self-attention layer of the decoder. + embed_dim_encoder : int + The dimension of the embedding for the encoder. + embed_dim_decoder : int + The dimension of the embedding for the decoder. + encoder_activation : Union[str, torch.nn.Module] + The activation to be used in all the encoder layers. + decoder_activation : Union[str, torch.nn.Module] + The activation to be used in all the decoder layers. + encoder_mlp_layer_config : dict + A configuration dictionary to instantiate the encoder MLP layer.weights + decoder_mlp_layer_config : dict + A configuration dictionary to instantiate the encoder MLP layer.weights + number_of_encoders : int + The number of encoders to be used. + number_of_decoders : int + The number of decoders to be used. + + """ super(Transformer, self).__init__() @@ -165,7 +270,6 @@ def __init__(self, num_heads_encoder:int=1, ] - self.weights = list() for e, encoder_e in enumerate(self.EncoderStage): @@ -179,15 +283,31 @@ def __init__(self, num_heads_encoder:int=1, @as_tensor def forward(self, input_data: Union[torch.Tensor, np.ndarray] = None) -> torch.Tensor: - encoder_output = self.EncoderStage(input_data) + """ + + Parameters + ---------- + input_data : Union[torch.Tensor, np.ndarray] + The input dataset. + + Returns + ------- + torch.Tensor + The transformer output. 
+ """ + + encoder_output = self.EncoderStage(input_data) - current_input = input_data - for decoder in self.DecoderStage: - output = decoder(input_data=current_input, encoder_output=encoder_output) - current_input = output + current_input = input_data + for decoder in self.DecoderStage: + output = decoder(input_data=current_input, encoder_output=encoder_output) + current_input = output - return output + return output def summary(self): + """ + It prints a general view of the architecture. + """ - print(self) + print(self)