Skip to content

Commit

Permalink
Merge pull request #1226 from AI-Hypercomputer:patemotter_tol_fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721981963
  • Loading branch information
maxtext authors committed Feb 1, 2025
2 parents d33821f + 3f97cca commit 232060f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions MaxText/tests/kernels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_ragged_mqa(self):
ragged_out, ragged_max, ragged_denom = ragged_mqa(q, k, v, lengths)
reference_out, reference_max, reference_denom = reference_mqa(q, k, v, lengths)
self.assertTrue(
jnp.max(abs(ragged_out - reference_out)) < 1e-1,
jnp.max(abs(ragged_out - reference_out)) < 1.5e-1,
msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1",
)
self.assertTrue(
Expand All @@ -71,7 +71,7 @@ def test_ragged_mha(self):
ragged_out = ragged_out / ragged_denom
reference_out, reference_max, reference_denom = reference_mha(q, k, v, lengths)
self.assertTrue(
jnp.max(abs(ragged_out - reference_out)) < 1e-1,
jnp.max(abs(ragged_out - reference_out)) < 1.5e-1,
msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1",
)
self.assertTrue(
Expand All @@ -96,7 +96,7 @@ def test_ragged_gqa(self):
jnp.squeeze(q), jnp.swapaxes(k, 1, 2), jnp.swapaxes(v, 1, 2), lengths
)
self.assertTrue(
jnp.max(abs(ragged_out - reference_out)) < 1e-1,
jnp.max(abs(ragged_out - reference_out)) < 1.5e-1,
msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1",
)
self.assertTrue(
Expand Down

0 comments on commit 232060f

Please sign in to comment.