From 6269357d36fc44b233a613ca103ae4dc2f3b1aeb Mon Sep 17 00:00:00 2001 From: Jinfeng Li Date: Tue, 2 Jan 2024 14:00:02 -0800 Subject: [PATCH] fix logistic regression ci failure when spark 3.3.0 is used (#537) * fix logistic regression ci failure when spark 3.3.0 is used * clean * fix bug per review comment --------- Signed-off-by: Jinfeng --- python/tests/test_logistic_regression.py | 29 +++++++++++++++++------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/python/tests/test_logistic_regression.py b/python/tests/test_logistic_regression.py index a581378d..26208732 100644 --- a/python/tests/test_logistic_regression.py +++ b/python/tests/test_logistic_regression.py @@ -1132,13 +1132,17 @@ def test_compat_one_label( ) if label < 0: - msg = f"Labels MUST be in [0, 2147483647), but got {label}" + spark_v34_msg = f"Labels MUST be in [0, 2147483647), but got {label}" + spark_v33_msg = ( + f"Classification labels should be in [0 to -1]. Found 4 invalid labels." + ) try: blor_model = blor.fit(bdf) assert False, "There should be a java exception" except Py4JJavaError as e: - assert msg in e.java_exception.getMessage() + java_msg = e.java_exception.getMessage() + assert spark_v34_msg in java_msg or spark_v33_msg in java_msg return @@ -1225,7 +1229,9 @@ def test_compat_wrong_label( feature_cols = ["c0", "c1"] schema = ["c0 float, c1 float, label float"] - def test_functor(y: np.ndarray, err_msg: str) -> None: + def test_functor( + y: np.ndarray, err_msg_spark_v34: str, err_msg_spark_v33: str + ) -> None: with CleanSparkSession() as spark: np_array = np.concatenate((X, y.reshape(num_rows, 1)), axis=1) @@ -1244,16 +1250,23 @@ def test_functor(y: np.ndarray, err_msg: str) -> None: lr.fit(df) assert False, "There should be a java exception" except Py4JJavaError as e: - assert err_msg in e.java_exception.getMessage() + java_msg = e.java_exception.getMessage() + assert err_msg_spark_v34 in java_msg or err_msg_spark_v33 in java_msg # negative label wrong_label = -1.1 y = np.array([1.0, 0.0, wrong_label, 2.0]) - msg = f"Labels MUST be in [0, 2147483647), but got {wrong_label}" - test_functor(y, msg) + spark_v34_msg = f"Labels MUST be in [0, 2147483647), but got {wrong_label}" + spark_v33_msg = ( + f"Classification labels should be in [0 to 2]. Found 1 invalid labels." + ) + test_functor(y, spark_v34_msg, spark_v33_msg) # non-integer label wrong_label = 0.4 y = np.array([1.0, 0.0, wrong_label, 2.0]) - msg = f"Labels MUST be Integers, but got {wrong_label}" - test_functor(y, msg) + spark_v34_msg = f"Labels MUST be Integers, but got {wrong_label}" + spark_v33_msg = ( + f"Classification labels should be in [0 to 2]. Found 1 invalid labels." + ) + test_functor(y, spark_v34_msg, spark_v33_msg)