From dc092ae8a28de6b6e3426dd0dc391645cb5eac0d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 18 Dec 2024 12:12:52 +0800 Subject: [PATCH] Validate reference dataset for training. (#11105) --- python-package/xgboost/core.py | 23 ++++++++++++-- python-package/xgboost/training.py | 15 ++++++++++ .../test_device_quantile_dmatrix.py | 2 +- tests/python/test_quantile_dmatrix.py | 30 ++++++++++++------- 4 files changed, 56 insertions(+), 14 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index ad0c77edeae6..f2d4fb548c8e 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1451,7 +1451,20 @@ def _ref_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None: ) -class QuantileDMatrix(DMatrix): +class _RefMixIn: + @property + def ref(self) -> Optional[weakref.ReferenceType]: + """Internal method for retrieving a reference to the training DMatrix.""" + if hasattr(self, "_ref"): + return self._ref + return None + + @ref.setter + def ref(self, ref: weakref.ReferenceType) -> None: + self._ref = ref + + +class QuantileDMatrix(DMatrix, _RefMixIn): """A DMatrix variant that generates quantilized data directly from input for the ``hist`` tree method. This DMatrix is primarily designed to save memory in training by avoiding intermediate storage. Set ``max_bin`` to control the number of bins @@ -1640,8 +1653,11 @@ def _init( _check_call(ret) self.handle = handle + if ref is not None: + self.ref = weakref.ref(ref) + -class ExtMemQuantileDMatrix(DMatrix): +class ExtMemQuantileDMatrix(DMatrix, _RefMixIn): """The external memory version of the :py:class:`QuantileDMatrix`. See :doc:`/tutorials/external_memory` for explanation and usage examples, and @@ -1739,6 +1755,9 @@ def _init( _check_call(ret) self.handle = handle + if ref is not None: + self.ref = weakref.ref(ref) + Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 29a516e81e24..3379df3add3b 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -3,6 +3,7 @@ """Training Library containing training routines.""" import copy import os +import weakref from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast import numpy as np @@ -147,6 +148,20 @@ def train( callbacks = [] if callbacks is None else copy.copy(list(callbacks)) evals = list(evals) if evals else [] + for va, _ in evals: + if not isinstance(va, DMatrix): + raise TypeError("Invalid type for the `evals`.") + + if ( + hasattr(va, "ref") + and va.ref is not weakref.ref(dtrain) + and va is not dtrain + ): + raise ValueError( + "Training dataset should be used as a reference when constructing " + "the `QuantileDMatrix` for evaluation." + ) + bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model) start_iteration = 0 diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py index d789dfab25e2..2f2e6545bf2d 100644 --- a/tests/python-gpu/test_device_quantile_dmatrix.py +++ b/tests/python-gpu/test_device_quantile_dmatrix.py @@ -175,7 +175,7 @@ def test_ref_dmatrix(self) -> None: import cupy as cp rng = cp.random.RandomState(np.uint64(1994)) - self.cputest.run_ref_dmatrix(rng, "gpu_hist", False) + self.cputest.run_ref_dmatrix(rng, "cuda", False) @given( strategies.integers(1, 1000), diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py index 19bce7317c66..e1152370732b 100644 --- a/tests/python/test_quantile_dmatrix.py +++ b/tests/python/test_quantile_dmatrix.py @@ -170,13 +170,13 @@ def test_training(self, sparsity: float) -> None: } xgb.train(parameters, Xy) - def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None: + def run_ref_dmatrix(self, rng: Any, device: str, enable_cat: bool) -> None: n_samples, n_features = 2048, 17 if enable_cat: X, y = make_categorical( n_samples, n_features, n_categories=13, onehot=False ) - if tree_method == "gpu_hist": + if device == "cuda": import cudf X = cudf.from_pandas(X) @@ -189,10 +189,12 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None: # Use ref Xy = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat) - Xy_valid = xgb.QuantileDMatrix(X, y, ref=Xy, enable_categorical=enable_cat) + Xy_valid: xgb.DMatrix = xgb.QuantileDMatrix( + X, y, ref=Xy, enable_categorical=enable_cat + ) qdm_results: Dict[str, Dict[str, List[float]]] = {} xgb.train( - {"tree_method": tree_method}, + {"tree_method": "hist", "device": device}, Xy, evals=[(Xy, "Train"), (Xy_valid, "valid")], evals_result=qdm_results, @@ -201,10 +203,10 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None: qdm_results["Train"]["rmse"], qdm_results["valid"]["rmse"] ) # No ref - Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat) + Xy_valid = xgb.DMatrix(X, y, enable_categorical=enable_cat) qdm_results = {} xgb.train( - {"tree_method": tree_method}, + {"tree_method": "hist", "device": device}, Xy, evals=[(Xy, "Train"), (Xy_valid, "valid")], evals_result=qdm_results, @@ -229,7 +231,7 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None: n_samples, n_features = 256, 17 if enable_cat: X, y = make_categorical(n_samples, n_features, 13, onehot=False) - if tree_method == "gpu_hist": + if device == "cuda": import cudf X = cudf.from_pandas(X) @@ -246,7 +248,7 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None: qdm_results = {} xgb.train( - {"tree_method": tree_method}, + {"tree_method": "hist", "device": device}, Xy, evals=[(Xy, "Train"), (Xy_valid, "valid")], evals_result=qdm_results, @@ -254,7 +256,7 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None: dm_results: Dict[str, Dict[str, List[float]]] = {} xgb.train( - {"tree_method": tree_method}, + {"tree_method": "hist", "device": device}, dXy, evals=[(dXy, "Train"), (dXy_valid, "valid"), (Xy_valid_d, "dvalid")], evals_result=dm_results, @@ -269,13 +271,19 @@ def run_ref_dmatrix(self, rng: Any, tree_method: str, enable_cat: bool) -> None: dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"] ) + Xy_valid = xgb.QuantileDMatrix(X, y, enable_categorical=enable_cat) + with pytest.raises(ValueError, match="should be used as a reference"): + xgb.train( + {"device": device}, dXy, evals=[(dXy, "Train"), (Xy_valid, "Valid")] + ) + def test_ref_quantile_cut(self) -> None: check_ref_quantile_cut("cpu") def test_ref_dmatrix(self) -> None: rng = np.random.RandomState(1994) - self.run_ref_dmatrix(rng, "hist", True) - self.run_ref_dmatrix(rng, "hist", False) + self.run_ref_dmatrix(rng, "cpu", True) + self.run_ref_dmatrix(rng, "cpu", False) @pytest.mark.parametrize("sparsity", [0.0, 0.5]) def test_predict(self, sparsity: float) -> None: