Skip to content

Commit

Permalink
support one_label
Browse files Browse the repository at this point in the history
Signed-off-by: Jinfeng <[email protected]>
  • Loading branch information
lijinf2 committed Dec 8, 2023
1 parent c2b997f commit 4d5dd40
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/src/spark_rapids_ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,12 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
"dtype": logistic_regression.dtype.name,
"num_iters": logistic_regression.solver_model.num_iters,
}

if len(logistic_regression.classes_) == 1:
if init_parameters["fit_intercept"] is True:
model["coef_"] = [[0.0] * logistic_regression.n_cols]
model["intercept_"] = [float("inf")]

del logistic_regression
return model

Expand Down Expand Up @@ -1027,6 +1033,16 @@ def _out_schema(self) -> Union[StructType, str]:
)

def _create_pyspark_model(self, result: Row) -> "LogisticRegressionModel":
if len(result["classes_"]) == 1:
if self.getFitIntercept() is False:
print(
"WARNING: All labels belong to a single class and fitIntercept=false. It's a dangerous ground, so the algorithm may not converge."
)
else:
print(
"WARNING: All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed."
)

return LogisticRegressionModel._from_row(result)

def _set_cuml_reg_params(self) -> "LogisticRegression":
Expand Down
68 changes: 68 additions & 0 deletions python/tests/test_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,3 +1076,71 @@ def test_parameters_validation() -> None:
# charge of validating it.
with pytest.raises(ValueError, match="C or regParam given invalid value -1.0"):
LogisticRegression().setRegParam(-1.0).fit(df)


@pytest.mark.compat
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize(
"lr_types",
[
# (SparkLogisticRegression, SparkLogisticRegressionModel),
(LogisticRegression, LogisticRegressionModel),
],
)
def test_compat_one_label(
fit_intercept: bool,
lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType],
tmp_path: str,
) -> None:
tolerance = 0.001
_LogisticRegression, _LogisticRegressionModel = lr_types

X = np.array(
[
[1.0, 2.0],
[1.0, 3.0],
[2.0, 1.0],
[3.0, 1.0],
]
)
y = np.array(
[
1.0,
1.0,
1.0,
1.0,
]
)
num_rows = len(X)

weight = np.ones([num_rows])
feature_cols = ["c0", "c1"]
schema = ["c0 float, c1 float, weight float, label float"]

with CleanSparkSession() as spark:
np_array = np.concatenate(
(X, weight.reshape(num_rows, 1), y.reshape(num_rows, 1)), axis=1
)

bdf = spark.createDataFrame(
np_array.tolist(),
",".join(schema),
)

bdf = bdf.withColumn("features", array_to_vector(array(*feature_cols))).drop(
*feature_cols
)

blor = _LogisticRegression(
regParam=0.1, fitIntercept=fit_intercept, standardization=False
)
blor_model = blor.fit(bdf)

if fit_intercept is False:
assert array_equal(
blor_model.coefficients.toArray(), [0.85431526, 0.85431526], tolerance
)
assert blor_model.intercept == 0.0
else:
assert array_equal(blor_model.coefficients.toArray(), [0, 0], 0.0)
assert blor_model.intercept == float("inf")

0 comments on commit 4d5dd40

Please sign in to comment.