[LogisticRegression] Match Spark CPU behaviors when dataset has one label #531
Conversation
Signed-off-by: Jinfeng <[email protected]>
build
if len(logistic_regression.classes_) == 1:
    if init_parameters["fit_intercept"] is True:
        model["coef_"] = [[0.0] * logistic_regression.n_cols]
        model["intercept_"] = [float("inf")]
Would the sign of this depend on the label value?
Revised to support -inf for label 0.
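For context, a minimal sketch of the intercept convention discussed here (the dict layout mirrors the diff above; this is not the PR's exact code): with a single class and fit_intercept enabled, the coefficients are all zeros and the intercept's sign follows the label value.

def single_class_model(n_cols, class_val):
    # Trivial single-class model: zero coefficients; the intercept is +inf
    # for label 1.0 and -inf for label 0.0, matching Spark CPU behavior.
    intercept = float("inf") if class_val == 1.0 else float("-inf")
    return {
        "coef_": [[0.0] * n_cols],
        "intercept_": [intercept],
    }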
if len(result["classes_"]) == 1:
    if self.getFitIntercept() is False:
        print(
            "WARNING: All labels belong to a single class and fitIntercept=false. It's a dangerous ground, so the algorithm may not converge."
Can we match Spark's warning?
Is there a way to capture a Spark Scala warning in Python? I tried caplog.set_level() with INFO, WARN, and CRITICAL but got empty log text.
Revised to use logger.warning
        )
    else:
        print(
            "WARNING: All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed."
Same here.
Revised
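A minimal sketch of the logger-based version of both warnings (the logger name is an assumption; the warning strings are the ones shown in the diff above):

import logging

logger = logging.getLogger(__name__)  # assumed logger; the PR may use a package-level one

def warn_single_class(fit_intercept):
    # Route the single-class warnings through logging instead of print()
    # so they can be captured in tests (e.g. with pytest's caplog).
    if not fit_intercept:
        logger.warning(
            "All labels belong to a single class and fitIntercept=false. "
            "It's a dangerous ground, so the algorithm may not converge."
        )
    else:
        logger.warning(
            "All labels are the same value and fitIntercept=true, so the "
            "coefficients will be zeros. Training is not needed."
        )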
    assert blor_model.intercept == 0.0
else:
    assert array_equal(blor_model.coefficients.toArray(), [0, 0], 0.0)
    assert blor_model.intercept == float("inf")
Maybe also check what happens when all labels are 0 instead of 1 (i.e., the y values).
added
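A compact sketch of the two expected outcomes the test now covers when fit_intercept is enabled (these are the values discussed above, not output from running the PR's code; in the real test the assertions run against the fitted blor_model):

import math

def expected_trivial_model(label, n_cols=2):
    # Zero coefficients; the intercept sign follows the single label value.
    return [0.0] * n_cols, (math.inf if label == 1.0 else -math.inf)

# Both single-label datasets, as suggested in the review above:
assert expected_trivial_model(1.0) == ([0.0, 0.0], math.inf)
assert expected_trivial_model(0.0) == ([0.0, 0.0], -math.inf)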
class_val = logistic_regression.classes_[0]
assert (
    class_val == 1.0 or class_val == 0.0
), "class value must be either 1. or 0. when dataset has one label"
What does Spark do if the label has only one value but it is not 1 or 0?
Revised.
If the label is < 0, a Java runtime error is raised.
If the label is > 1, Spark trains a multinomial classification, while cuML trains a single-class classification because it uses y.unique().
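A small illustration of the class-derivation difference described above (a sketch of the behavior, not the PR's code): cuML derives classes_ from the unique label values, while Spark conceptually infers the number of classes from the maximum label, so a dataset whose only label is 2.0 is single-class for cuML but multinomial for Spark, and a negative label is rejected on the Spark side.

import numpy as np

y = np.array([2.0, 2.0, 2.0])

# cuML side: classes_ comes from the unique label values -> one class.
cuml_classes = np.unique(y)
print(cuml_classes, len(cuml_classes))   # [2.] 1

# Spark side (conceptually): labels are assumed to be 0..maxLabel, so the
# same data is treated as a 3-class problem; a label < 0 raises an error.
spark_num_classes = int(y.max()) + 1
print(spark_num_classes)                 # 3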
build
blor_model = blor.fit(bdf)

if fit_intercept is False:
    if label == 1.0:
Can caplog be used to check the warning in this case? Like here: https://github.com/NVIDIA/spark-rapids-ml/blob/branch-23.12/python/tests/test_nearest_neighbors.py#L47
Any update on this?
You will probably need to patch the CI docker image in this PR to get tests to pass, as rapidsai-nightly no longer has cuML 23.12. Switch to the rapidsai channel.
build
Added the caplog and a test case to check an invalid label. Just updated the CI docker image, and yes, it seems CI can run.
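A hedged sketch of the caplog-based check (the warning text is from the diff above; the logging call and logger name below are stand-ins for the blor.fit(bdf) that emits the warning in the real test):

import logging

logger = logging.getLogger("spark_rapids_ml")  # assumed logger name

def emit_single_class_warning():
    # Stand-in for blor.fit(bdf), which logs this warning when every row
    # has the same label and fitIntercept=true.
    logger.warning(
        "All labels are the same value and fitIntercept=true, so the "
        "coefficients will be zeros. Training is not needed."
    )

def test_single_class_warning(caplog):
    # caplog only sees records that go through the logging module, which is
    # why the warnings were switched from print() to logger.warning.
    with caplog.at_level(logging.WARNING):
        emit_single_class_warning()
    assert "Training is not needed" in caplog.text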
build
👍