diff --git a/MaxText/tests/kernels_test.py b/MaxText/tests/kernels_test.py index 6313aa884..3c73ca10d 100644 --- a/MaxText/tests/kernels_test.py +++ b/MaxText/tests/kernels_test.py @@ -48,7 +48,7 @@ def test_ragged_mqa(self): ragged_out, ragged_max, ragged_denom = ragged_mqa(q, k, v, lengths) reference_out, reference_max, reference_denom = reference_mqa(q, k, v, lengths) self.assertTrue( - jnp.max(abs(ragged_out - reference_out)) < 1e-1, + jnp.max(abs(ragged_out - reference_out)) < 1.5e-1, msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1", ) self.assertTrue( @@ -71,7 +71,7 @@ def test_ragged_mha(self): ragged_out = ragged_out / ragged_denom reference_out, reference_max, reference_denom = reference_mha(q, k, v, lengths) self.assertTrue( - jnp.max(abs(ragged_out - reference_out)) < 1e-1, + jnp.max(abs(ragged_out - reference_out)) < 1.5e-1, msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1", ) self.assertTrue( @@ -96,7 +96,7 @@ def test_ragged_gqa(self): jnp.squeeze(q), jnp.swapaxes(k, 1, 2), jnp.swapaxes(v, 1, 2), lengths ) self.assertTrue( - jnp.max(abs(ragged_out - reference_out)) < 1e-1, + jnp.max(abs(ragged_out - reference_out)) < 1.5e-1, msg=f"Max difference: {jnp.max(abs(ragged_out - reference_out))} > 1e-1", ) self.assertTrue(