-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LogisticRegression] Match Spark CPU behaviors when dataset has one label #531
Changes from 2 commits
3bb0d83
094481a
427dcf2
d21f64c
a3b4185
5f1e494
ae63854
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -976,6 +976,18 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]: | |
"dtype": logistic_regression.dtype.name, | ||
"num_iters": logistic_regression.solver_model.num_iters, | ||
} | ||
|
||
if len(logistic_regression.classes_) == 1: | ||
class_val = logistic_regression.classes_[0] | ||
assert ( | ||
class_val == 1.0 or class_val == 0.0 | ||
), "class value must be either 1. or 0. when dataset has one label" | ||
if init_parameters["fit_intercept"] is True: | ||
model["coef_"] = [[0.0] * logistic_regression.n_cols] | ||
model["intercept_"] = [ | ||
float("inf") if class_val == 1.0 else float("-inf") | ||
] | ||
|
||
del logistic_regression | ||
return model | ||
|
||
|
@@ -1027,6 +1039,16 @@ def _out_schema(self) -> Union[StructType, str]: | |
) | ||
|
||
def _create_pyspark_model(self, result: Row) -> "LogisticRegressionModel": | ||
if len(result["classes_"]) == 1: | ||
if self.getFitIntercept() is False: | ||
print( | ||
"WARNING: All labels belong to a single class and fitIntercept=false. It's a dangerous ground, so the algorithm may not converge." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we match spark's warning? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way to capture spark scala warning in python? I tried caplog.set_level() to INFO, WARN, CRITICAL but got empty log text. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Revised to use logger.warning |
||
) | ||
else: | ||
print( | ||
"WARNING: All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Revised |
||
) | ||
|
||
return LogisticRegressionModel._from_row(result) | ||
|
||
def _set_cuml_reg_params(self) -> "LogisticRegression": | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1076,3 +1076,76 @@ def test_parameters_validation() -> None: | |
# charge of validating it. | ||
with pytest.raises(ValueError, match="C or regParam given invalid value -1.0"): | ||
LogisticRegression().setRegParam(-1.0).fit(df) | ||
|
||
|
||
@pytest.mark.compat | ||
@pytest.mark.parametrize("fit_intercept", [True, False]) | ||
@pytest.mark.parametrize("label", [1.0, 0.0]) | ||
@pytest.mark.parametrize( | ||
"lr_types", | ||
[ | ||
(SparkLogisticRegression, SparkLogisticRegressionModel), | ||
(LogisticRegression, LogisticRegressionModel), | ||
], | ||
) | ||
def test_compat_one_label( | ||
fit_intercept: bool, | ||
label: float, | ||
lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType], | ||
) -> None: | ||
assert label == 1.0 or label == 0.0 | ||
|
||
tolerance = 0.001 | ||
_LogisticRegression, _LogisticRegressionModel = lr_types | ||
|
||
X = np.array( | ||
[ | ||
[1.0, 2.0], | ||
[1.0, 3.0], | ||
[2.0, 1.0], | ||
[3.0, 1.0], | ||
] | ||
) | ||
y = np.array([label] * 4) | ||
num_rows = len(X) | ||
|
||
feature_cols = ["c0", "c1"] | ||
schema = ["c0 float, c1 float, label float"] | ||
|
||
with CleanSparkSession() as spark: | ||
np_array = np.concatenate((X, y.reshape(num_rows, 1)), axis=1) | ||
|
||
bdf = spark.createDataFrame( | ||
np_array.tolist(), | ||
",".join(schema), | ||
) | ||
|
||
bdf = bdf.withColumn("features", array_to_vector(array(*feature_cols))).drop( | ||
*feature_cols | ||
) | ||
|
||
blor = _LogisticRegression( | ||
regParam=0.1, fitIntercept=fit_intercept, standardization=False | ||
) | ||
|
||
blor_model = blor.fit(bdf) | ||
|
||
if fit_intercept is False: | ||
if label == 1.0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can caplog be used to check warning in this case? Like here: https://github.com/NVIDIA/spark-rapids-ml/blob/branch-23.12/python/tests/test_nearest_neighbors.py#L47 |
||
assert array_equal( | ||
blor_model.coefficients.toArray(), | ||
[0.85431526, 0.85431526], | ||
tolerance, | ||
) | ||
else: | ||
assert array_equal( | ||
blor_model.coefficients.toArray(), | ||
[-0.85431526, -0.85431526], | ||
tolerance, | ||
) | ||
assert blor_model.intercept == 0.0 | ||
else: | ||
assert array_equal(blor_model.coefficients.toArray(), [0, 0], 0.0) | ||
assert blor_model.intercept == ( | ||
float("inf") if label == 1.0 else float("-inf") | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does spark do if label has one value but is not 1 or 0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revised.
If the label is < 0, a Java RuntimeError pops up.
If the label is > 1, Spark trains a multinomial classification, whereas cuML trains a single-class classification because it uses y.unique().