Skip to content

Commit

Permalink
More sklearn tag support. (#11162)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jan 13, 2025
1 parent 3a2a85d commit d43c759
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions ops/docker_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def docker_run(
docker_run_cli_args.extend(
itertools.chain.from_iterable([["-e", f"{k}={v}"] for k, v in user_ids.items()])
)
docker_run_cli_args.extend(["-e", "NCCL_RAS_ENABLE=0"])
docker_run_cli_args.extend(extra_args)
docker_run_cli_args.append(image_uri)
docker_run_cli_args.extend(command_args)
Expand Down
10 changes: 8 additions & 2 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,9 +810,11 @@ def __init__(

def _more_tags(self) -> Dict[str, bool]:
"""Tags used for scikit-learn data validation."""
tags = {"allow_nan": True, "no_validation": True}
tags = {"allow_nan": True, "no_validation": True, "sparse": True}
if hasattr(self, "kwargs") and self.kwargs.get("updater") == "shotgun":
tags["non_deterministic"] = True

tags["categorical"] = self.enable_categorical
return tags

@staticmethod
Expand All @@ -826,11 +828,15 @@ def _update_sklearn_tags_from_dict(
``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags.
ref: https://github.com/scikit-learn/scikit-learn/pull/29677
This method handles updating that instance based on the values in ``self._more_tags()``.
This method handles updating that instance based on the values in
``self._more_tags()``.
"""
tags.non_deterministic = tags_dict.get("non_deterministic", False)
tags.no_validation = tags_dict["no_validation"]
tags.input_tags.allow_nan = tags_dict["allow_nan"]
tags.input_tags.sparse = tags_dict["sparse"]
tags.input_tags.categorical = tags_dict["categorical"]
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,7 +1350,7 @@ def test_data_initialization() -> None:
validate_data_initialization(xgb.QuantileDMatrix, xgb.XGBClassifier, X, y)


@parametrize_with_checks([xgb.XGBRegressor()])
@parametrize_with_checks([xgb.XGBRegressor(enable_categorical=True)])
def test_estimator_reg(estimator, check):
if os.environ["PYTEST_CURRENT_TEST"].find("check_supervised_y_no_nan") != -1:
# The test uses float64 and requires the error message to contain:
Expand Down

0 comments on commit d43c759

Please sign in to comment.