Skip to content

Commit

Permalink
Add debug prints to training test and roll back to 10% ATOL
Browse files Browse the repository at this point in the history
  • Loading branch information
kmabeeTT committed Jan 13, 2025
1 parent 0d880a3 commit 7367bb9
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion forge/test/mlir/mnist/training/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,16 @@ def test_mnist_training_with_grad_accumulation():
total_loss += loss.item()

golden_loss = loss_fn(golden_pred, target)
assert torch.allclose(loss, golden_loss, rtol=5e-2) # 5% tolerance

print(f"Loss: {loss}")
print(f"Golden Loss: {golden_loss}")

diff = torch.abs(loss - golden_loss)
relative_diff = diff / torch.abs(golden_loss)
print(f"Absolute Difference: {diff}")
print(f"Relative Difference: {relative_diff}")

assert torch.allclose(loss, golden_loss, rtol=1e-1) # 10% tolerance

# Run backward pass on device
loss.backward()
Expand Down

0 comments on commit 7367bb9

Please sign in to comment.