
Commit

BUG Fix: Adding norm before generating the prediction
Mouradost committed Aug 14, 2023
1 parent 2bdbb46 commit e86f364
Showing 2 changed files with 8 additions and 3 deletions.
src/trainers/base.py (3 additions, 1 deletion)
@@ -131,7 +131,9 @@ def get_predictions(
                 x.to(self.device),
             )
             prediction = self.scaler.reverse(
-                self.model(x=x, adj=adj, adj_hat=sim, idx=idx).cpu().detach()
+                self.model(x=self.scaler.norm(x), adj=adj, adj_hat=sim, idx=idx)
+                .cpu()
+                .detach()
             )
             all_grounds.append(y.detach())
             all_predictions.append(prediction)
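The change mirrors the training path at inference time: the raw input is normalized with the scaler before it reaches the model, and the output is mapped back to the original scale with scaler.reverse. As an illustration only, a minimal z-score scaler exposing the same norm/reverse method names could look like the sketch below (this Scaler class is an assumption for clarity, not the repository's actual implementation):

import torch

class Scaler:
    # Hypothetical z-score scaler; only the norm/reverse interface matches the diff.
    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean, self.std = mean, std

    def norm(self, x: torch.Tensor) -> torch.Tensor:
        # Map raw inputs into the normalized space the model was trained on.
        return (x - self.mean) / self.std

    def reverse(self, x: torch.Tensor) -> torch.Tensor:
        # Map normalized model outputs back to the original data scale.
        return x * self.std + self.mean

Under that assumption, the pre-fix code passed un-normalized x to a model trained on normalized inputs, so scaler.reverse was applied to outputs on the wrong scale; normalizing first keeps prediction consistent with training.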
src/trainers/full.py (5 additions, 2 deletions)
@@ -187,7 +187,7 @@ def loss(
         if self.model_pred_single:
             y = y[:, -1:]
 
-        contrib = self.e_model(x=x, idx=idx)
+        contrib = self.e_model(x=self.scaler.norm(x), idx=idx)
         e_loss = self.loss_fn(contrib, sim, dim=(0, 1, 2))
         loss = self.loss_fn(
             pred=self.scaler.reverse(
@@ -218,7 +218,10 @@ def get_predictions(
             )
             prediction = self.scaler.reverse(
                 self.model(
-                    x=x, adj=adj, adj_hat=self.e_model(x=x, idx=idx), idx=idx
+                    x=self.scaler.norm(x),
+                    adj=adj,
+                    adj_hat=self.e_model(x=self.scaler.norm(x), idx=idx),
+                    idx=idx,
                 )
                 .cpu()
                 .detach()
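In the full trainer the normalized input feeds two networks: e_model, which produces the adj_hat contribution, and the main model. A purely illustrative refactor (the name x_norm is an assumption; the committed diff keeps the two inline scaler.norm(x) calls) would normalize once and reuse the result:

# Normalize once and reuse for both the auxiliary and the main model.
x_norm = self.scaler.norm(x)
adj_hat = self.e_model(x=x_norm, idx=idx)
prediction = self.scaler.reverse(
    self.model(x=x_norm, adj=adj, adj_hat=adj_hat, idx=idx).cpu().detach()
)

Either form applies the same fix; the point of the commit is that both the loss computation and the prediction path now see normalized inputs.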
