Skip to content

Commit

Permalink
Revert "XGBoost - Use DaskDMatrix for evals data to ensure metrics in…
Browse files Browse the repository at this point in the history
… logs match result of evaluate (#682)" (#762)

This reverts commit 489137e.
  • Loading branch information
oliverholworthy authored Sep 23, 2022
1 parent 461484d commit 96f31b8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 13 deletions.
5 changes: 1 addition & 4 deletions merlin/models/xgb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,7 @@ def fit(
self.target_columns,
self.qid_column,
)
# using the quantile DMatrix as part of evals results in a
# discrepancy between metrics reported in logs and result
# of evaluate
d_eval = xgb.dask.DaskDMatrix(self.dask_client, X, label=y, qid=qid)
d_eval = dmatrix_cls(self.dask_client, X, label=y, qid=qid)
watchlist.append((d_eval, name))

train_res = xgb.dask.train(
Expand Down
16 changes: 7 additions & 9 deletions tests/unit/xgb/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,26 +121,24 @@ def test_pairwise(self, social_data: Dataset):
],
)
@patch("xgboost.dask.train", side_effect=xgboost.dask.train)
def test_gpu_hist_dmatrix(mock_train, fit_kwargs, expected_dtrain_cls, dask_client):
train, valid = generate_data("music-streaming", num_rows=100, set_sizes=(0.5, 0.5))
schema = train.schema
def test_gpu_hist_dmatrix(
mock_train, fit_kwargs, expected_dtrain_cls, dask_client, music_streaming_data: Dataset
):
schema = music_streaming_data.schema
model = XGBoost(schema, objective="reg:logistic", tree_method="gpu_hist")
model.fit(train, evals=[(valid, "valid")], **fit_kwargs)
model.predict(valid)
metrics = model.evaluate(valid)
model.fit(music_streaming_data, **fit_kwargs)
model.predict(music_streaming_data)
metrics = model.evaluate(music_streaming_data)
assert "rmse" in metrics

assert mock_train.called
assert mock_train.call_count == 1

train_call = mock_train.call_args_list[0]
client, params, dtrain = train_call.args
evals = train_call.kwargs["evals"]
assert dask_client == client
assert params["tree_method"] == "gpu_hist"
assert params["objective"] == "reg:logistic"
# check that we don't use quantile dmatrix for non-training eval data
assert not isinstance(evals[0][0], xgboost.dask.DaskDeviceQuantileDMatrix)
assert isinstance(dtrain, expected_dtrain_cls)


Expand Down

0 comments on commit 96f31b8

Please sign in to comment.