diff --git a/python/src/spark_rapids_ml/classification.py b/python/src/spark_rapids_ml/classification.py index 878d6256..a1c1af57 100644 --- a/python/src/spark_rapids_ml/classification.py +++ b/python/src/spark_rapids_ml/classification.py @@ -976,6 +976,12 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]: "dtype": logistic_regression.dtype.name, "num_iters": logistic_regression.solver_model.num_iters, } + + if len(logistic_regression.classes_) == 1: + if init_parameters["fit_intercept"] is True: + model["coef_"] = [[0.0] * logistic_regression.n_cols] + model["intercept_"] = [float("inf")] + del logistic_regression return model @@ -1027,6 +1033,16 @@ def _out_schema(self) -> Union[StructType, str]: ) def _create_pyspark_model(self, result: Row) -> "LogisticRegressionModel": + if len(result["classes_"]) == 1: + if self.getFitIntercept() is False: + print( + "WARNING: All labels belong to a single class and fitIntercept=false. It's a dangerous ground, so the algorithm may not converge." + ) + else: + print( + "WARNING: All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed." + ) + return LogisticRegressionModel._from_row(result) def _set_cuml_reg_params(self) -> "LogisticRegression": diff --git a/python/tests/test_logistic_regression.py b/python/tests/test_logistic_regression.py index f2f92d0f..62855b03 100644 --- a/python/tests/test_logistic_regression.py +++ b/python/tests/test_logistic_regression.py @@ -1076,3 +1076,71 @@ def test_parameters_validation() -> None: # charge of validating it. with pytest.raises(ValueError, match="C or regParam given invalid value -1.0"): LogisticRegression().setRegParam(-1.0).fit(df) + + +@pytest.mark.compat +@pytest.mark.parametrize("fit_intercept", [True, False]) +@pytest.mark.parametrize( + "lr_types", + [ + # (SparkLogisticRegression, SparkLogisticRegressionModel), + (LogisticRegression, LogisticRegressionModel), + ], +) +def test_compat_one_label( + fit_intercept: bool, + lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType], + tmp_path: str, +) -> None: + tolerance = 0.001 + _LogisticRegression, _LogisticRegressionModel = lr_types + + X = np.array( + [ + [1.0, 2.0], + [1.0, 3.0], + [2.0, 1.0], + [3.0, 1.0], + ] + ) + y = np.array( + [ + 1.0, + 1.0, + 1.0, + 1.0, + ] + ) + num_rows = len(X) + + weight = np.ones([num_rows]) + feature_cols = ["c0", "c1"] + schema = ["c0 float, c1 float, weight float, label float"] + + with CleanSparkSession() as spark: + np_array = np.concatenate( + (X, weight.reshape(num_rows, 1), y.reshape(num_rows, 1)), axis=1 + ) + + bdf = spark.createDataFrame( + np_array.tolist(), + ",".join(schema), + ) + + bdf = bdf.withColumn("features", array_to_vector(array(*feature_cols))).drop( + *feature_cols + ) + + blor = _LogisticRegression( + regParam=0.1, fitIntercept=fit_intercept, standardization=False + ) + blor_model = blor.fit(bdf) + + if fit_intercept is False: + assert array_equal( + blor_model.coefficients.toArray(), [0.85431526, 0.85431526], tolerance + ) + assert blor_model.intercept == 0.0 + else: + assert array_equal(blor_model.coefficients.toArray(), [0, 0], 0.0) + assert blor_model.intercept == float("inf")