From f2db9d699d66db71ca9ed105af2178b68e7c9d38 Mon Sep 17 00:00:00 2001
From: Bora Uyar
Date: Thu, 4 Apr 2024 16:21:15 +0200
Subject: [PATCH] define activation function at the beginning; use sigmoid
 instead of tanh for decoder output

---
 flexynesis/modules.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/flexynesis/modules.py b/flexynesis/modules.py
index ac61e08..8f909be 100644
--- a/flexynesis/modules.py
+++ b/flexynesis/modules.py
@@ -19,19 +19,19 @@ class Encoder(nn.Module):
     def __init__(self, input_dim, hidden_dims, latent_dim):
         super(Encoder, self).__init__()
 
-        self.LeakyReLU = nn.LeakyReLU(0.2)
+        self.act = nn.LeakyReLU(0.2)
 
         hidden_layers = []
 
         hidden_layers.append(nn.Linear(input_dim, hidden_dims[0]))
         nn.init.xavier_uniform_(hidden_layers[-1].weight)
-        hidden_layers.append(self.LeakyReLU)
+        hidden_layers.append(self.act)
         hidden_layers.append(nn.BatchNorm1d(hidden_dims[0]))
 
         for i in range(len(hidden_dims)-1):
             hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
             nn.init.xavier_uniform_(hidden_layers[-1].weight)
-            hidden_layers.append(self.LeakyReLU)
+            hidden_layers.append(self.act)
             hidden_layers.append(nn.BatchNorm1d(hidden_dims[i+1]))
 
         self.hidden_layers = nn.Sequential(*hidden_layers)
@@ -68,19 +68,19 @@ class Decoder(nn.Module):
     def __init__(self, latent_dim, hidden_dims, output_dim):
         super(Decoder, self).__init__()
 
-        self.LeakyReLU = nn.LeakyReLU(0.2)
+        self.act = nn.LeakyReLU(0.2)
 
         hidden_layers = []
 
         hidden_layers.append(nn.Linear(latent_dim, hidden_dims[0]))
         nn.init.xavier_uniform_(hidden_layers[-1].weight)
-        hidden_layers.append(self.LeakyReLU)
+        hidden_layers.append(self.act)
         hidden_layers.append(nn.BatchNorm1d(hidden_dims[0]))
 
         for i in range(len(hidden_dims) - 1):
             hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
             nn.init.xavier_uniform_(hidden_layers[-1].weight)
-            hidden_layers.append(self.LeakyReLU)
+            hidden_layers.append(self.act)
             hidden_layers.append(nn.BatchNorm1d(hidden_dims[i+1]))
 
         self.hidden_layers = nn.Sequential(*hidden_layers)
@@ -99,7 +99,7 @@ def forward(self, x):
 
         Returns:
             x_hat (torch.Tensor): The reconstructed output tensor.
         """
         h = self.hidden_layers(x)
-        x_hat = torch.tanh(self.FC_output(h))
+        x_hat = torch.sigmoid(self.FC_output(h))
         return x_hat
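
Note for reviewers: the sketch below shows the resulting pattern outside the
diff context. It mirrors the patched Decoder (a single shared activation,
self.act, defined once in __init__, and sigmoid on the reconstruction). The
definition of FC_output and the smoke-test sizes are assumptions for
illustration, not taken from flexynesis/modules.py.

import torch
import torch.nn as nn

class Decoder(nn.Module):
    # Minimal sketch of the patched Decoder, not the full flexynesis module.
    def __init__(self, latent_dim, hidden_dims, output_dim):
        super().__init__()
        # One shared activation, defined once up front (the point of this patch).
        self.act = nn.LeakyReLU(0.2)

        layers = [nn.Linear(latent_dim, hidden_dims[0])]
        nn.init.xavier_uniform_(layers[-1].weight)
        layers += [self.act, nn.BatchNorm1d(hidden_dims[0])]
        for i in range(len(hidden_dims) - 1):
            layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
            nn.init.xavier_uniform_(layers[-1].weight)
            layers += [self.act, nn.BatchNorm1d(hidden_dims[i + 1])]
        self.hidden_layers = nn.Sequential(*layers)

        # FC_output is referenced but not defined in the hunks above; a plain
        # Linear to the output dimension is an assumed reconstruction.
        self.FC_output = nn.Linear(hidden_dims[-1], output_dim)

    def forward(self, x):
        h = self.hidden_layers(x)
        # Patched output activation: sigmoid keeps x_hat in (0, 1).
        return torch.sigmoid(self.FC_output(h))

# Illustrative smoke test; the sizes are made up, not from the patch.
dec = Decoder(latent_dim=16, hidden_dims=[64, 128], output_dim=200)
x_hat = dec(torch.randn(8, 16))
assert x_hat.shape == (8, 200)
assert x_hat.min() >= 0 and x_hat.max() <= 1

The behavioral part of the change is easy to miss behind the renames: sigmoid
bounds reconstructions to (0, 1), which is the appropriate choice if the
model's inputs are scaled to [0, 1], whereas the previous tanh produced values
in (-1, 1).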