From 3bb0d83a32ade29d972766036c422884d5c8a0bf Mon Sep 17 00:00:00 2001 From: Jinfeng Date: Tue, 14 Nov 2023 00:54:38 -0800 Subject: [PATCH 1/7] support one_label Signed-off-by: Jinfeng --- python/src/spark_rapids_ml/classification.py | 16 +++++ python/tests/test_logistic_regression.py | 68 ++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/python/src/spark_rapids_ml/classification.py b/python/src/spark_rapids_ml/classification.py index 878d6256..a1c1af57 100644 --- a/python/src/spark_rapids_ml/classification.py +++ b/python/src/spark_rapids_ml/classification.py @@ -976,6 +976,12 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]: "dtype": logistic_regression.dtype.name, "num_iters": logistic_regression.solver_model.num_iters, } + + if len(logistic_regression.classes_) == 1: + if init_parameters["fit_intercept"] is True: + model["coef_"] = [[0.0] * logistic_regression.n_cols] + model["intercept_"] = [float("inf")] + del logistic_regression return model @@ -1027,6 +1033,16 @@ def _out_schema(self) -> Union[StructType, str]: ) def _create_pyspark_model(self, result: Row) -> "LogisticRegressionModel": + if len(result["classes_"]) == 1: + if self.getFitIntercept() is False: + print( + "WARNING: All labels belong to a single class and fitIntercept=false. It's a dangerous ground, so the algorithm may not converge." + ) + else: + print( + "WARNING: All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed." + ) + return LogisticRegressionModel._from_row(result) def _set_cuml_reg_params(self) -> "LogisticRegression": diff --git a/python/tests/test_logistic_regression.py b/python/tests/test_logistic_regression.py index f2f92d0f..845e2933 100644 --- a/python/tests/test_logistic_regression.py +++ b/python/tests/test_logistic_regression.py @@ -1076,3 +1076,71 @@ def test_parameters_validation() -> None: # charge of validating it. 
with pytest.raises(ValueError, match="C or regParam given invalid value -1.0"): LogisticRegression().setRegParam(-1.0).fit(df) + + +@pytest.mark.compat +@pytest.mark.parametrize("fit_intercept", [True, False]) +@pytest.mark.parametrize( + "lr_types", + [ + (SparkLogisticRegression, SparkLogisticRegressionModel), + (LogisticRegression, LogisticRegressionModel), + ], +) +def test_compat_one_label( + fit_intercept: bool, + lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType], + tmp_path: str, +) -> None: + tolerance = 0.001 + _LogisticRegression, _LogisticRegressionModel = lr_types + + X = np.array( + [ + [1.0, 2.0], + [1.0, 3.0], + [2.0, 1.0], + [3.0, 1.0], + ] + ) + y = np.array( + [ + 1.0, + 1.0, + 1.0, + 1.0, + ] + ) + num_rows = len(X) + + weight = np.ones([num_rows]) + feature_cols = ["c0", "c1"] + schema = ["c0 float, c1 float, weight float, label float"] + + with CleanSparkSession() as spark: + np_array = np.concatenate( + (X, weight.reshape(num_rows, 1), y.reshape(num_rows, 1)), axis=1 + ) + + bdf = spark.createDataFrame( + np_array.tolist(), + ",".join(schema), + ) + + bdf = bdf.withColumn("features", array_to_vector(array(*feature_cols))).drop( + *feature_cols + ) + + blor = _LogisticRegression( + regParam=0.1, fitIntercept=fit_intercept, standardization=False + ) + blor_model = blor.fit(bdf) + + if fit_intercept is False: + assert array_equal( + blor_model.coefficients.toArray(), [0.85431526, 0.85431526], tolerance + ) + assert blor_model.intercept == 0.0 + else: + assert array_equal(blor_model.coefficients.toArray(), [0, 0], 0.0) + assert blor_model.intercept == float("inf") From 094481a532df04f27585dcc660b8935aaddedc0e Mon Sep 17 00:00:00 2001 From: Jinfeng Date: Fri, 8 Dec 2023 17:50:33 -0800 Subject: [PATCH 2/7] revise PR --- python/src/spark_rapids_ml/classification.py | 8 +++- python/tests/test_logistic_regression.py | 41 +++++++++++--------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/python/src/spark_rapids_ml/classification.py b/python/src/spark_rapids_ml/classification.py index a1c1af57..8b221f9d 100644 --- a/python/src/spark_rapids_ml/classification.py +++ b/python/src/spark_rapids_ml/classification.py @@ -978,9 +978,15 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]: } if len(logistic_regression.classes_) == 1: + class_val = logistic_regression.classes_[0] + assert ( + class_val == 1.0 or class_val == 0.0 + ), "class value must be either 1. or 0. 
when dataset has one label" if init_parameters["fit_intercept"] is True: model["coef_"] = [[0.0] * logistic_regression.n_cols] - model["intercept_"] = [float("inf")] + model["intercept_"] = [ + float("inf") if class_val == 1.0 else float("-inf") + ] del logistic_regression return model diff --git a/python/tests/test_logistic_regression.py b/python/tests/test_logistic_regression.py index 845e2933..9c49f55f 100644 --- a/python/tests/test_logistic_regression.py +++ b/python/tests/test_logistic_regression.py @@ -1080,6 +1080,7 @@ def test_parameters_validation() -> None: @pytest.mark.compat @pytest.mark.parametrize("fit_intercept", [True, False]) +@pytest.mark.parametrize("label", [1.0, 0.0]) @pytest.mark.parametrize( "lr_types", [ @@ -1089,9 +1090,11 @@ def test_parameters_validation() -> None: ) def test_compat_one_label( fit_intercept: bool, + label: float, lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType], - tmp_path: str, ) -> None: + assert label == 1.0 or label == 0.0 + tolerance = 0.001 _LogisticRegression, _LogisticRegressionModel = lr_types @@ -1103,24 +1106,14 @@ def test_compat_one_label( [3.0, 1.0], ] ) - y = np.array( - [ - 1.0, - 1.0, - 1.0, - 1.0, - ] - ) + y = np.array([label] * 4) num_rows = len(X) - weight = np.ones([num_rows]) feature_cols = ["c0", "c1"] - schema = ["c0 float, c1 float, weight float, label float"] + schema = ["c0 float, c1 float, label float"] with CleanSparkSession() as spark: - np_array = np.concatenate( - (X, weight.reshape(num_rows, 1), y.reshape(num_rows, 1)), axis=1 - ) + np_array = np.concatenate((X, y.reshape(num_rows, 1)), axis=1) bdf = spark.createDataFrame( np_array.tolist(), @@ -1134,13 +1127,25 @@ def test_compat_one_label( blor = _LogisticRegression( regParam=0.1, fitIntercept=fit_intercept, standardization=False ) + blor_model = blor.fit(bdf) if fit_intercept is False: - assert array_equal( - blor_model.coefficients.toArray(), [0.85431526, 0.85431526], tolerance - ) + if label == 1.0: + assert array_equal( + blor_model.coefficients.toArray(), + [0.85431526, 0.85431526], + tolerance, + ) + else: + assert array_equal( + blor_model.coefficients.toArray(), + [-0.85431526, -0.85431526], + tolerance, + ) assert blor_model.intercept == 0.0 else: assert array_equal(blor_model.coefficients.toArray(), [0, 0], 0.0) - assert blor_model.intercept == float("inf") + assert blor_model.intercept == ( + float("inf") if label == 1.0 else float("-inf") + ) From 427dcf28a4b69a035a56f335c55e40f33bac8d65 Mon Sep 17 00:00:00 2001 From: Jinfeng Date: Wed, 13 Dec 2023 15:30:25 -0800 Subject: [PATCH 3/7] added label_val < 0, label_val > 1, and logging --- python/src/spark_rapids_ml/classification.py | 31 ++++++++++++-------- python/tests/test_logistic_regression.py | 30 +++++++++++++++++-- 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/python/src/spark_rapids_ml/classification.py b/python/src/spark_rapids_ml/classification.py index 8b221f9d..f79db819 100644 --- a/python/src/spark_rapids_ml/classification.py +++ b/python/src/spark_rapids_ml/classification.py @@ -979,14 +979,20 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]: if len(logistic_regression.classes_) == 1: class_val = logistic_regression.classes_[0] - assert ( - class_val == 1.0 or class_val == 0.0 - ), "class value must be either 1. or 0. 
when dataset has one label" - if init_parameters["fit_intercept"] is True: - model["coef_"] = [[0.0] * logistic_regression.n_cols] - model["intercept_"] = [ - float("inf") if class_val == 1.0 else float("-inf") - ] + + if class_val < 0: + raise RuntimeError( + f"Labels MUST be in [0, 2147483647), but got {class_val}" + ) + elif class_val <= 1: + assert ( + class_val == 1.0 or class_val == 0.0 + ), "class value must be either 1. or 0. when dataset has one label" + if init_parameters["fit_intercept"] is True: + model["coef_"] = [[0.0] * logistic_regression.n_cols] + model["intercept_"] = [ + float("inf") if class_val == 1.0 else float("-inf") + ] del logistic_regression return model @@ -1039,14 +1045,15 @@ def _out_schema(self) -> Union[StructType, str]: ) def _create_pyspark_model(self, result: Row) -> "LogisticRegressionModel": + logger = get_logger(self.__class__) if len(result["classes_"]) == 1: if self.getFitIntercept() is False: - print( - "WARNING: All labels belong to a single class and fitIntercept=false. It's a dangerous ground, so the algorithm may not converge." + logger.warning( + "All labels belong to a single class and fitIntercept=false. It's a dangerous ground, so the algorithm may not converge." ) else: - print( - "WARNING: All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed." + logger.warning( + "All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed." ) return LogisticRegressionModel._from_row(result) diff --git a/python/tests/test_logistic_regression.py b/python/tests/test_logistic_regression.py index 9c49f55f..412cff32 100644 --- a/python/tests/test_logistic_regression.py +++ b/python/tests/test_logistic_regression.py @@ -1080,7 +1080,7 @@ def test_parameters_validation() -> None: @pytest.mark.compat @pytest.mark.parametrize("fit_intercept", [True, False]) -@pytest.mark.parametrize("label", [1.0, 0.0]) +@pytest.mark.parametrize("label", [1.0, 0.0, -3.0, 4.0]) @pytest.mark.parametrize( "lr_types", [ @@ -1093,7 +1093,7 @@ def test_compat_one_label( label: float, lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType], ) -> None: - assert label == 1.0 or label == 0.0 + assert label % 1 == 0.0, "label value must be an integer" tolerance = 0.001 _LogisticRegression, _LogisticRegressionModel = lr_types @@ -1107,6 +1107,7 @@ def test_compat_one_label( ] ) y = np.array([label] * 4) + num_rows = len(X) feature_cols = ["c0", "c1"] @@ -1128,6 +1129,31 @@ def test_compat_one_label( regParam=0.1, fitIntercept=fit_intercept, standardization=False ) + if label < 0: + from py4j.protocol import Py4JJavaError + + msg = f"Labels MUST be in [0, 2147483647), but got {label}" + + try: + blor_model = blor.fit(bdf) + assert False, "There should be a java exception" + except Py4JJavaError as e: + assert msg in e.java_exception.getMessage() + + return + + if label > 1: # Spark and Cuml do not match + if _LogisticRegression is SparkLogisticRegression: + blor_model = blor.fit(bdf) + assert blor_model.numClasses == label + 1 + else: + blor_model = blor.fit(bdf) + assert blor_model.numClasses == 1 + + return + + assert label == 1.0 or label == 0.0 + blor_model = blor.fit(bdf) if fit_intercept is False: From d21f64cfb92be939172bcfd2dfcfd6c48f12ec6d Mon Sep 17 00:00:00 2001 From: Jinfeng Date: Wed, 27 Dec 2023 11:49:59 -0800 Subject: [PATCH 4/7] add caplog to caputre warning message --- python/tests/test_logistic_regression.py | 19 +++++++++++++++++++ 1 file changed, 19 
insertions(+) diff --git a/python/tests/test_logistic_regression.py b/python/tests/test_logistic_regression.py index 412cff32..6954273b 100644 --- a/python/tests/test_logistic_regression.py +++ b/python/tests/test_logistic_regression.py @@ -1092,6 +1092,7 @@ def test_compat_one_label( fit_intercept: bool, label: float, lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType], + caplog: LogCaptureFixture, ) -> None: assert label % 1 == 0.0, "label value must be an integer" @@ -1157,6 +1158,15 @@ def test_compat_one_label( blor_model = blor.fit(bdf) if fit_intercept is False: + if _LogisticRegression is SparkLogisticRegression: + # Got empty caplog.text. Spark prints warning message from jvm + assert caplog.text == "" + else: + assert ( + "All labels belong to a single class and fitIntercept=false. It's a dangerous ground, so the algorithm may not converge." + in caplog.text + ) + if label == 1.0: assert array_equal( blor_model.coefficients.toArray(), @@ -1171,6 +1181,15 @@ def test_compat_one_label( ) assert blor_model.intercept == 0.0 else: + if _LogisticRegression is SparkLogisticRegression: + # Got empty caplog.text. Spark prints warning message from jvm + assert caplog.text == "" + else: + assert ( + "All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed." + in caplog.text + ) + assert array_equal(blor_model.coefficients.toArray(), [0, 0], 0.0) assert blor_model.intercept == ( float("inf") if label == 1.0 else float("-inf") From a3b418545317a4b9c7ba837d3e7498fbaf9c1535 Mon Sep 17 00:00:00 2001 From: Jinfeng Date: Wed, 27 Dec 2023 13:53:37 -0800 Subject: [PATCH 5/7] add test_compat_wrong_label --- python/src/spark_rapids_ml/classification.py | 33 +++++---- python/tests/test_logistic_regression.py | 71 ++++++++++++++++++-- 2 files changed, 88 insertions(+), 16 deletions(-) diff --git a/python/src/spark_rapids_ml/classification.py b/python/src/spark_rapids_ml/classification.py index f79db819..594e8322 100644 --- a/python/src/spark_rapids_ml/classification.py +++ b/python/src/spark_rapids_ml/classification.py @@ -977,22 +977,31 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]: "num_iters": logistic_regression.solver_model.num_iters, } - if len(logistic_regression.classes_) == 1: - class_val = logistic_regression.classes_[0] - + # check if invalid label exists + for class_val in logistic_regression.classes_: if class_val < 0: raise RuntimeError( f"Labels MUST be in [0, 2147483647), but got {class_val}" ) - elif class_val <= 1: - assert ( - class_val == 1.0 or class_val == 0.0 - ), "class value must be either 1. or 0. when dataset has one label" - if init_parameters["fit_intercept"] is True: - model["coef_"] = [[0.0] * logistic_regression.n_cols] - model["intercept_"] = [ - float("inf") if class_val == 1.0 else float("-inf") - ] + elif not class_val.is_integer(): + raise RuntimeError( + f"Labels MUST be Integers, but got {class_val}" + ) + + if len(logistic_regression.classes_) == 1: + class_val = logistic_regression.classes_[0] + # TODO: match Spark to use max(class_list) to calculate the number of classes + # Cuml currently uses unique(class_list) + if class_val != 1.0 and class_val != 0.0: + raise RuntimeError( + "class value must be either 1. or 0. 
when dataset has one label" + ) + + if init_parameters["fit_intercept"] is True: + model["coef_"] = [[0.0] * logistic_regression.n_cols] + model["intercept_"] = [ + float("inf") if class_val == 1.0 else float("-inf") + ] del logistic_regression return model diff --git a/python/tests/test_logistic_regression.py b/python/tests/test_logistic_regression.py index 6954273b..a581378d 100644 --- a/python/tests/test_logistic_regression.py +++ b/python/tests/test_logistic_regression.py @@ -6,6 +6,7 @@ import pytest from _pytest.logging import LogCaptureFixture from packaging import version +from py4j.protocol import Py4JJavaError if version.parse(pyspark.__version__) < version.parse("3.4.0"): from pyspark.sql.utils import IllegalArgumentException # type: ignore @@ -1131,8 +1132,6 @@ def test_compat_one_label( ) if label < 0: - from py4j.protocol import Py4JJavaError - msg = f"Labels MUST be in [0, 2147483647), but got {label}" try: @@ -1148,8 +1147,11 @@ def test_compat_one_label( blor_model = blor.fit(bdf) assert blor_model.numClasses == label + 1 else: - blor_model = blor.fit(bdf) - assert blor_model.numClasses == 1 + msg = "class value must be either 1. or 0. when dataset has one label" + try: + blor_model = blor.fit(bdf) + except Py4JJavaError as e: + assert msg in e.java_exception.getMessage() return @@ -1194,3 +1196,64 @@ def test_compat_one_label( assert blor_model.intercept == ( float("inf") if label == 1.0 else float("-inf") ) + + +@pytest.mark.compat +@pytest.mark.parametrize( + "lr_types", + [ + (SparkLogisticRegression, SparkLogisticRegressionModel), + (LogisticRegression, LogisticRegressionModel), + ], +) +def test_compat_wrong_label( + lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType], + caplog: LogCaptureFixture, +) -> None: + _LogisticRegression, _LogisticRegressionModel = lr_types + + X = np.array( + [ + [1.0, 2.0], + [1.0, 3.0], + [2.0, 1.0], + [3.0, 1.0], + ] + ) + + num_rows = len(X) + feature_cols = ["c0", "c1"] + schema = ["c0 float, c1 float, label float"] + + def test_functor(y: np.ndarray, err_msg: str) -> None: + with CleanSparkSession() as spark: + np_array = np.concatenate((X, y.reshape(num_rows, 1)), axis=1) + + df = spark.createDataFrame( + np_array.tolist(), + ",".join(schema), + ) + + df = df.withColumn("features", array_to_vector(array(*feature_cols))).drop( + *feature_cols + ) + + lr = _LogisticRegression(standardization=False) + + try: + lr.fit(df) + assert False, "There should be a java exception" + except Py4JJavaError as e: + assert err_msg in e.java_exception.getMessage() + + # negative label + wrong_label = -1.1 + y = np.array([1.0, 0.0, wrong_label, 2.0]) + msg = f"Labels MUST be in [0, 2147483647), but got {wrong_label}" + test_functor(y, msg) + + # non-integer label + wrong_label = 0.4 + y = np.array([1.0, 0.0, wrong_label, 2.0]) + msg = f"Labels MUST be Integers, but got {wrong_label}" + test_functor(y, msg) From 5f1e494daf3570c7f3e814edda57177381dc7e45 Mon Sep 17 00:00:00 2001 From: Jinfeng Date: Wed, 27 Dec 2023 13:59:16 -0800 Subject: [PATCH 6/7] update rapidsai-nightly channel to rapidsai --- ci/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/Dockerfile b/ci/Dockerfile index b0b47678..41dbdf0d 100644 --- a/ci/Dockerfile +++ b/ci/Dockerfile @@ -38,5 +38,5 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86 # install cuML ARG CUML_VER=23.12 -RUN conda install -y -c rapidsai-nightly -c conda-forge -c nvidia cuml=$CUML_VER python=3.9 cuda-version=11.8 \ +RUN conda install -y 
-c rapidsai -c conda-forge -c nvidia cuml=$CUML_VER python=3.9 cuda-version=11.8 \ && conda clean --all -f -y

From ae63854a070f6ad74053c8f9043c406b5f05c080 Mon Sep 17 00:00:00 2001
From: Jinfeng
Date: Wed, 27 Dec 2023 15:35:33 -0800
Subject: [PATCH 7/7] fix CI failure due to the class_val variable being a numpy array instead of a float

---
 python/src/spark_rapids_ml/classification.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/src/spark_rapids_ml/classification.py b/python/src/spark_rapids_ml/classification.py
index 594e8322..7bc96857 100644
--- a/python/src/spark_rapids_ml/classification.py
+++ b/python/src/spark_rapids_ml/classification.py
@@ -978,7 +978,7 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
             }
 
             # check if invalid label exists
-            for class_val in logistic_regression.classes_:
+            for class_val in model["classes_"]:
                 if class_val < 0:
                     raise RuntimeError(
                         f"Labels MUST be in [0, 2147483647), but got {class_val}"
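
This series adds single-label and invalid-label handling to spark_rapids_ml's LogisticRegression. As a quick reference for reviewers, the sketch below (not part of any patch; it assumes a GPU-enabled local Spark session and this branch installed) walks through the behavior the tests above assert: an all-one-label fit with fitIntercept=True returns zero coefficients and an infinite intercept, and out-of-range labels raise a RuntimeError.

# Illustrative sketch only -- not part of the patch series. Assumes a GPU-enabled
# local Spark session and the spark_rapids_ml package built from this branch.
from pyspark.sql import SparkSession
from pyspark.sql.functions import array
from pyspark.ml.functions import array_to_vector

from spark_rapids_ml.classification import LogisticRegression

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Every row carries the same label value (1.0).
df = spark.createDataFrame(
    [(1.0, 2.0, 1.0), (1.0, 3.0, 1.0), (2.0, 1.0, 1.0), (3.0, 1.0, 1.0)],
    "c0 float, c1 float, label float",
)
df = df.withColumn("features", array_to_vector(array("c0", "c1"))).drop("c0", "c1")

# With fitIntercept=True and a single label class of 1.0, training is skipped:
# coefficients come back as zeros and the intercept as +inf (mirroring Spark ML,
# which uses -inf when the single label is 0.0).
model = LogisticRegression(
    regParam=0.1, fitIntercept=True, standardization=False
).fit(df)
print(model.coefficients)  # expected: [0.0, 0.0]
print(model.intercept)     # expected: inf

# Negative or non-integer labels instead raise a RuntimeError, e.g.
# "Labels MUST be in [0, 2147483647), but got -1.1".
spark.stop()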