Skip to content

Commit

Permalink
Validate reference dataset for training. (#11105)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Dec 18, 2024
1 parent f06dcf8 commit dc092ae
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 14 deletions.
23 changes: 21 additions & 2 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,7 +1451,20 @@ def _ref_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None:
)


class QuantileDMatrix(DMatrix):
class _RefMixIn:
@property
def ref(self) -> Optional[weakref.ReferenceType]:
"""Internal method for retrieving a reference to the training DMatrix."""
if hasattr(self, "_ref"):
return self._ref
return None

@ref.setter
def ref(self, ref: weakref.ReferenceType) -> None:
self._ref = ref


class QuantileDMatrix(DMatrix, _RefMixIn):
"""A DMatrix variant that generates quantilized data directly from input for the
``hist`` tree method. This DMatrix is primarily designed to save memory in training
by avoiding intermediate storage. Set ``max_bin`` to control the number of bins
Expand Down Expand Up @@ -1640,8 +1653,11 @@ def _init(
_check_call(ret)
self.handle = handle

if ref is not None:
self.ref = weakref.ref(ref)


class ExtMemQuantileDMatrix(DMatrix):
class ExtMemQuantileDMatrix(DMatrix, _RefMixIn):
"""The external memory version of the :py:class:`QuantileDMatrix`.
See :doc:`/tutorials/external_memory` for explanation and usage examples, and
Expand Down Expand Up @@ -1739,6 +1755,9 @@ def _init(
_check_call(ret)
self.handle = handle

if ref is not None:
self.ref = weakref.ref(ref)


Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
Expand Down
15 changes: 15 additions & 0 deletions python-package/xgboost/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Training Library containing training routines."""
import copy
import os
import weakref
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
Expand Down Expand Up @@ -147,6 +148,20 @@ def train(
callbacks = [] if callbacks is None else copy.copy(list(callbacks))
evals = list(evals) if evals else []

for va, _ in evals:
if not isinstance(va, DMatrix):
raise TypeError("Invalid type for the `evals`.")

if (
hasattr(va, "ref")
and va.ref is not weakref.ref(dtrain)
and va is not dtrain
):
raise ValueError(
"Training dataset should be used as a reference when constructing "
"the `QuantileDMatrix` for evaluation."
)

bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
start_iteration = 0

Expand Down
2 changes: 1 addition & 1 deletion tests/python-gpu/test_device_quantile_dmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_ref_dmatrix(self) -> None:
import cupy as cp

rng = cp.random.RandomState(np.uint64(1994))
self.cputest.run_ref_dmatrix(rng, "gpu_hist", False)
self.cputest.run_ref_dmatrix(rng, "cuda", False)

@given(
strategies.integers(1, 1000),
Expand Down
30 changes: 19 additions & 11 deletions tests/python/test_quantile_dmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ def test_training(self, sparsity: float) -> None:
}
xgb.train(parameters, Xy)

def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
def run_ref_dmatrix(self, rng: Any, device: str, enable_cat: bool) -> None:
n_samples, n_features = 2048, 17
if enable_cat:
X, y = make_categorical(
n_samples, n_features, n_categories=13, onehot=False
)
if tree_method == "gpu_hist":
if device == "cuda":
import cudf

X = cudf.from_pandas(X)
Expand All @@ -189,10 +189,12 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:

# Use ref
Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat)
Xy_valid: xgb.DMatrix = xgb.QuantileDMatrix(
X, y, ref=Xy, enable_categorical=enable_cat
)
qdm_results: Dict[str, Dict[str, List[float]]] = {}
xgb.train(
{"tree_method": tree_method},
{"tree_method": "hist", "device": device},
Xy,
evals=[(Xy, "Train"), (Xy_valid, "valid")],
evals_result=qdm_results,
Expand All @@ -201,10 +203,10 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"]
)
# No ref
Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
Xy_valid = xgb.DMatrix(X, y, enable_categorical=enable_cat)
qdm_results = {}
xgb.train(
{"tree_method": tree_method},
{"tree_method": "hist", "device": device},
Xy,
evals=[(Xy, "Train"), (Xy_valid, "valid")],
evals_result=qdm_results,
Expand All @@ -229,7 +231,7 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
n_samples, n_features = 256, 17
if enable_cat:
X, y = make_categorical(n_samples, n_features, 13, onehot=False)
if tree_method == "gpu_hist":
if device == "cuda":
import cudf

X = cudf.from_pandas(X)
Expand All @@ -246,15 +248,15 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:

qdm_results = {}
xgb.train(
{"tree_method": tree_method},
{"tree_method": "hist", "device": device},
Xy,
evals=[(Xy, "Train"), (Xy_valid, "valid")],
evals_result=qdm_results,
)

dm_results: Dict[str, Dict[str, List[float]]] = {}
xgb.train(
{"tree_method": tree_method},
{"tree_method": "hist", "device": device},
dXy,
evals=[(dXy, "Train"), (dXy_valid, "valid"), (Xy_valid_d, "dvalid")],
evals_result=dm_results,
Expand All @@ -269,13 +271,19 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None:
dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"]
)

Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat)
with pytest.raises(ValueError, match="should be used as a reference"):
xgb.train(
{"device": device}, dXy, evals=[(dXy, "Train"), (Xy_valid, "Valid")]
)

def test_ref_quantile_cut(self) -> None:
check_ref_quantile_cut("cpu")

def test_ref_dmatrix(self) -> None:
rng = np.random.RandomState(1994)
self.run_ref_dmatrix(rng, "hist", True)
self.run_ref_dmatrix(rng, "hist", False)
self.run_ref_dmatrix(rng, "cpu", True)
self.run_ref_dmatrix(rng, "cpu", False)

@pytest.mark.parametrize("sparsity", [0.0, 0.5])
def test_predict(self, sparsity: float) -> None:
Expand Down

0 comments on commit dc092ae

Please sign in to comment.