Skip to content

Commit

Permalink
Handle cudf.pandas proxy objects properly (#11014)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Hyunsu Cho <[email protected]>
Co-authored-by: Jiaming Yuan <[email protected]>
  • Loading branch information
3 people authored Nov 25, 2024
1 parent e988b7c commit 0e48cdc
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2669,6 +2669,7 @@ def inplace_predict(
_arrow_transform,
_is_arrow,
_is_cudf_df,
_is_cudf_pandas,
_is_cupy_alike,
_is_list,
_is_np_array_like,
Expand All @@ -2678,6 +2679,9 @@ def inplace_predict(
_transform_pandas_df,
)

if _is_cudf_pandas(data):
data = data._fsproxy_fast # pylint: disable=protected-access

enable_categorical = True
if _is_arrow(data):
data = _arrow_transform(data)
Expand Down
16 changes: 16 additions & 0 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,16 @@ def _is_cudf_df(data: DataType) -> bool:
return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")


def _is_cudf_pandas(data: DataType) -> bool:
"""Must go before both pandas and cudf checks."""
return (
lazy_isinstance(data, "pandas.core.frame", "DataFrame")
or lazy_isinstance(data, "pandas.core.series", "Series")
) and lazy_isinstance(
type(data), "cudf.pandas.fast_slow_proxy", "_FastSlowProxyMeta"
)


def _get_cudf_cat_predicate() -> Callable[[Any], bool]:
try:
from cudf import CategoricalDtype
Expand Down Expand Up @@ -1237,6 +1247,8 @@ def dispatch_data_backend(
)
if _is_arrow(data):
data = _arrow_transform(data)
if _is_cudf_pandas(data):
data = data._fsproxy_fast # pylint: disable=protected-access
if _is_pandas_series(data):
import pandas as pd

Expand Down Expand Up @@ -1409,6 +1421,8 @@ def dispatch_meta_backend(
return
if _is_arrow(data):
data = _arrow_transform(data)
if _is_cudf_pandas(data):
data = data._fsproxy_fast # pylint: disable=protected-access
if _is_pandas_df(data):
_meta_from_pandas_df(data, name, dtype=dtype, handle=handle)
return
Expand Down Expand Up @@ -1480,6 +1494,8 @@ def _proxy_transform(
feature_types: Optional[FeatureTypes],
enable_categorical: bool,
) -> TransformedData:
if _is_cudf_pandas(data):
data = data._fsproxy_fast # pylint: disable=protected-access
if _is_cudf_df(data) or _is_cudf_ser(data):
return _transform_cudf_df(
data, feature_names, feature_types, enable_categorical
Expand Down

0 comments on commit 0e48cdc

Please sign in to comment.