Skip to content

Commit

Permalink
revise PR
Browse files (browse the repository at this point in the history)
  • Loading branch information
lijinf2 committed Dec 9, 2023
1 parent 3bb0d83 commit 094481a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
8 changes: 7 additions & 1 deletion python/src/spark_rapids_ml/classification.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -978,9 +978,15 @@ def _single_fit(init_parameters: Dict[str, Any]) -> Dict[str, Any]:
}

if len(logistic_regression.classes_) == 1:
class_val = logistic_regression.classes_[0]
assert (
class_val == 1.0 or class_val == 0.0
), "class value must be either 1. or 0. when dataset has one label"
if init_parameters["fit_intercept"] is True:
model["coef_"] = [[0.0] * logistic_regression.n_cols]
model["intercept_"] = [float("inf")]
model["intercept_"] = [
float("inf") if class_val == 1.0 else float("-inf")
]

del logistic_regression
return model
Expand Down
41 changes: 23 additions & 18 deletions python/tests/test_logistic_regression.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1080,6 +1080,7 @@ def test_parameters_validation() -> None:

@pytest.mark.compat
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("label", [1.0, 0.0])
@pytest.mark.parametrize(
"lr_types",
[
Expand All @@ -1089,9 +1090,11 @@ def test_parameters_validation() -> None:
)
def test_compat_one_label(
fit_intercept: bool,
label: float,
lr_types: Tuple[LogisticRegressionType, LogisticRegressionModelType],
tmp_path: str,
) -> None:
assert label == 1.0 or label == 0.0

tolerance = 0.001
_LogisticRegression, _LogisticRegressionModel = lr_types

Expand All @@ -1103,24 +1106,14 @@ def test_compat_one_label(
[3.0, 1.0],
]
)
y = np.array(
[
1.0,
1.0,
1.0,
1.0,
]
)
y = np.array([label] * 4)
num_rows = len(X)

weight = np.ones([num_rows])
feature_cols = ["c0", "c1"]
schema = ["c0 float, c1 float, weight float, label float"]
schema = ["c0 float, c1 float, label float"]

with CleanSparkSession() as spark:
np_array = np.concatenate(
(X, weight.reshape(num_rows, 1), y.reshape(num_rows, 1)), axis=1
)
np_array = np.concatenate((X, y.reshape(num_rows, 1)), axis=1)

bdf = spark.createDataFrame(
np_array.tolist(),
Expand All @@ -1134,13 +1127,25 @@ def test_compat_one_label(
blor = _LogisticRegression(
regParam=0.1, fitIntercept=fit_intercept, standardization=False
)

blor_model = blor.fit(bdf)

if fit_intercept is False:
assert array_equal(
blor_model.coefficients.toArray(), [0.85431526, 0.85431526], tolerance
)
if label == 1.0:
assert array_equal(
blor_model.coefficients.toArray(),
[0.85431526, 0.85431526],
tolerance,
)
else:
assert array_equal(
blor_model.coefficients.toArray(),
[-0.85431526, -0.85431526],
tolerance,
)
assert blor_model.intercept == 0.0
else:
assert array_equal(blor_model.coefficients.toArray(), [0, 0], 0.0)
assert blor_model.intercept == float("inf")
assert blor_model.intercept == (
float("inf") if label == 1.0 else float("-inf")
)

0 comments on commit 094481a

Please sign in to comment.