From e86f36409b1056cef25cda2c84d896cc4854ec49 Mon Sep 17 00:00:00 2001 From: Lablack Mourad Date: Mon, 14 Aug 2023 13:50:12 +0800 Subject: [PATCH] BUG Fix: Adding norm before generating the prediction --- src/trainers/base.py | 4 +++- src/trainers/full.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/trainers/base.py b/src/trainers/base.py index 6545ec4..e860ac6 100644 --- a/src/trainers/base.py +++ b/src/trainers/base.py @@ -131,7 +131,9 @@ def get_predictions( x.to(self.device), ) prediction = self.scaler.reverse( - self.model(x=x, adj=adj, adj_hat=sim, idx=idx).cpu().detach() + self.model(x=self.scaler.norm(x), adj=adj, adj_hat=sim, idx=idx) + .cpu() + .detach() ) all_grounds.append(y.detach()) all_predictions.append(prediction) diff --git a/src/trainers/full.py b/src/trainers/full.py index bfd1b3b..f8684a8 100644 --- a/src/trainers/full.py +++ b/src/trainers/full.py @@ -187,7 +187,7 @@ def loss( if self.model_pred_single: y = y[:, -1:] - contrib = self.e_model(x=x, idx=idx) + contrib = self.e_model(x=self.scaler.norm(x), idx=idx) e_loss = self.loss_fn(contrib, sim, dim=(0, 1, 2)) loss = self.loss_fn( pred=self.scaler.reverse( @@ -218,7 +218,10 @@ def get_predictions( ) prediction = self.scaler.reverse( self.model( - x=x, adj=adj, adj_hat=self.e_model(x=x, idx=idx), idx=idx + x=self.scaler.norm(x), + adj=adj, + adj_hat=self.e_model(x=self.scaler.norm(x), idx=idx), + idx=idx, ) .cpu() .detach()