From 221961b077eb1fcdef70691da503feb187f688c0 Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Wed, 3 Apr 2024 12:06:58 +0200 Subject: [PATCH] improve decoding function; decode all output layers; convert to data frame and add row/col names --- flexynesis/models/crossmodal_pred.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index 1825f4e..9e8691c 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -83,7 +83,6 @@ def __init__(self, config, dataset, target_variables = None, batch_variables = [int(output_dims[i] * config['hidden_dim_factor'])], output_dims[i]) for i in range(len(self.output_layers))]) - print(self.decoders) # define supervisor heads # using ModuleDict to store multiple MLPs @@ -342,11 +341,18 @@ def decode(self, dataset): """ Extract the decoded values of the target/output layers """ - self.eval() - x_list = [dataset.dat[x] for x in self.input_layers] - X_hat, z, mean, log_var, outputs = self.forward(x_list) - return X_hat + x_list_input = [dataset.dat[x] for x in self.input_layers] + x_list_output = [dataset.dat[x] for x in self.output_layers] + x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input, x_list_output) + X = {} + for i in range(len(self.output_layers)): + x = pd.DataFrame(x_hat_list[i].detach().numpy()).transpose() + layer = self.output_layers[i] + x.columns = dataset.samples + x.index = dataset.features[layer] + X[layer] = x + return X def compute_kernel(self, x, y):