diff --git a/tests/test_cython_methods.py b/tests/test_cython_methods.py index a2135c2..0aedcb6 100644 --- a/tests/test_cython_methods.py +++ b/tests/test_cython_methods.py @@ -6,6 +6,7 @@ class TestCythonMethods(): @pytest.mark.parametrize("prec", [np.float32, np.float64]) def test_2D_vec(self, prec): + rtol = (1e-3 if prec == np.float32 else 1e-4) A = np.random.normal(size=(100, 2, 2)).astype(prec) A[..., 1, 0] = A[..., 0, 1] #Symmetrize @@ -14,7 +15,7 @@ def test_2D_vec(self, prec): expected_result = np.einsum('...x,...xy,...y->...', z1, A, z2) comp_result = metric_norm_matrix_2D_cython(A, z1, z2, ret_sqrt=False) - assert np.allclose(expected_result, comp_result, rtol=1e-4) #Single vectorized + assert np.allclose(expected_result, comp_result, rtol=rtol) #Broadcasted A = np.random.normal(size=(100, 1, 2, 2)) @@ -25,10 +26,11 @@ def test_2D_vec(self, prec): expected_result = np.einsum('...x,...xy,...y->...', z1, A, z2) comp_result = metric_norm_matrix_2D_cython(A, z1, z2, ret_sqrt=False) - assert np.allclose(expected_result, comp_result, rtol=1e-4) #Single vectorized + assert np.allclose(expected_result, comp_result, rtol=rtol) #Broadcasted @pytest.mark.parametrize("prec", [np.float32, np.float64]) def test_3D_vec(self, prec): + rtol = (1e-3 if prec == np.float32 else 1e-4) A = np.random.normal(size=(100, 3, 3)).astype(prec) A[..., 1, 0] = A[..., 0, 1] #Symmetrize A[..., 2, 0] = A[..., 0, 2] @@ -39,7 +41,7 @@ def test_3D_vec(self, prec): expected_result = np.einsum('...x,...xy,...y->...', z1, A, z2) comp_result = metric_norm_matrix_3D_cython(A, z1, z2, ret_sqrt=False) - assert np.allclose(expected_result, comp_result, rtol=1e-4) #Single vectorized + assert np.allclose(expected_result, comp_result, rtol=rtol) #Single vectorized #Broadcasted A = np.random.normal(size=(100, 1, 3, 3)) @@ -52,6 +54,6 @@ def test_3D_vec(self, prec): expected_result = np.einsum('...x,...xy,...y->...', z1, A, z2) comp_result = metric_norm_matrix_3D_cython(A, z1, z2, ret_sqrt=False) - assert np.allclose(expected_result, comp_result, rtol=1e-4) #Single vectorized + assert np.allclose(expected_result, comp_result, rtol=rtol)