diff --git a/README.rst b/README.rst
index 36d71dba..ba6c4a08 100644
--- a/README.rst
+++ b/README.rst
@@ -40,7 +40,7 @@ It provides support for the following machine learning frameworks and packages:
   XGBRegressor and xgboost.Booster.
 
 * LightGBM_ - show feature importances and explain predictions of
-  LGBMClassifier and LGBMRegressor.
+  LGBMClassifier, LGBMRegressor and lightgbm.Booster.
 
 * lightning_ - explain weights and predictions of lightning classifiers
   and regressors.
diff --git a/docs/source/libraries/lightgbm.rst b/docs/source/libraries/lightgbm.rst
index 6d612d47..c5031da3 100644
--- a/docs/source/libraries/lightgbm.rst
+++ b/docs/source/libraries/lightgbm.rst
@@ -5,13 +5,12 @@ LightGBM
 
 LightGBM_ is a fast Gradient Boosting framework; it provides a Python
 interface. eli5 supports :func:`eli5.explain_weights`
-and :func:`eli5.explain_prediction` for ``lightgbm.LGBMClassifer``
-and ``lightgbm.LGBMRegressor`` estimators.
+and :func:`eli5.explain_prediction` for ``lightgbm.LGBMClassifier``, ``lightgbm.LGBMRegressor`` and ``lightgbm.Booster`` estimators.
 
 .. _LightGBM: https://github.com/Microsoft/LightGBM
 
 :func:`eli5.explain_weights` uses feature importances. Additional
-arguments for LGBMClassifier and LGBMClassifier:
+arguments for LGBMClassifier, LGBMRegressor and lightgbm.Booster:
 
 * ``importance_type`` is a way to get feature importance. Possible values
   are:
@@ -22,7 +21,7 @@ arguments for LGBMClassifier and LGBMClassifier:
   - 'weight' - the same as 'split', for better compatibility with
     :ref:`library-xgboost`.
 
-``target_names`` and ``target`` arguments are ignored.
+``target_names`` argument is ignored for ``lightgbm.LGBMClassifier`` / ``lightgbm.LGBMRegressor``, but used for ``lightgbm.Booster``. ``targets`` argument is ignored.
 
 .. note::
     Top-level :func:`eli5.explain_weights` calls are dispatched
@@ -37,7 +36,7 @@ contribution of a feature on the decision path is how much the score
 changes from parent to child.
 
 Additional :func:`eli5.explain_prediction` keyword arguments supported
-for ``lightgbm.LGBMClassifer`` and ``lightgbm.LGBMRegressor``:
+for ``lightgbm.LGBMClassifier``, ``lightgbm.LGBMRegressor`` and ``lightgbm.Booster``:
 
 * ``vec`` is a vectorizer instance used to transform
   raw features to the input of the estimator ``lgb``
@@ -50,6 +49,14 @@ for ``lightgbm.LGBMClassifer`` and ``lightgbm.LGBMRegressor``:
   estimator. Set it to True if you're passing ``vec``,
   but ``doc`` is already vectorized.
 
+``lightgbm.Booster`` estimator accepts one more optional argument:
+
+* ``is_regression`` - True if solving a regression problem
+  ("objective" starts with "reg")
+  and False for a classification problem.
+  If not set, regression is assumed for a single-target estimator,
+  and probabilities are not shown unless ``target_names`` is given as a list of two labels.
+
 .. note::
     Top-level :func:`eli5.explain_prediction` calls are dispatched
     to :func:`eli5.xgboost.explain_prediction_lightgbm` for
diff --git a/docs/source/overview.rst b/docs/source/overview.rst
index a11ed2fd..49a4b7f5 100644
--- a/docs/source/overview.rst
+++ b/docs/source/overview.rst
@@ -31,7 +31,7 @@ following machine learning frameworks and packages:
   of XGBClassifier, XGBRegressor and xgboost.Booster.
 
 * :ref:`library-lightgbm` - show feature importances and explain predictions
-  of LGBMClassifier and LGBMRegressor.
+  of LGBMClassifier, LGBMRegressor and lightgbm.Booster.
 
 * :ref:`library-lightning` - explain weights and predictions of lightning
   classifiers and regressors.
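The docs above describe the new ``lightgbm.Booster`` support but do not show it
end to end. A minimal usage sketch (illustrative only, not part of the patch;
the synthetic data and variable names ``X``, ``y``, ``booster`` are made up,
while the eli5 and LightGBM calls are the public APIs this patch targets)::

    import lightgbm
    import numpy as np

    import eli5

    X = np.random.random((100, 3))
    y = X[:, 0] + 2 * X[:, 1]  # synthetic regression target

    booster = lightgbm.train(
        params={'objective': 'regression', 'verbose': -1},
        train_set=lightgbm.Dataset(X, label=y),
    )

    # Feature importances for a raw Booster (previously only the sklearn
    # wrappers LGBMClassifier / LGBMRegressor were accepted).
    print(eli5.format_as_text(eli5.explain_weights(booster)))

    # Per-prediction explanation; is_regression may be omitted here because
    # a single-target Booster defaults to regression under this patch.
    print(eli5.format_as_text(eli5.explain_prediction(booster, X[0])))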
diff --git a/eli5/lightgbm.py b/eli5/lightgbm.py
index 02ea4411..31571948 100644
--- a/eli5/lightgbm.py
+++ b/eli5/lightgbm.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, division
 from collections import defaultdict
-from typing import DefaultDict
+from typing import DefaultDict, Any, Tuple, Optional
 
 import numpy as np  # type: ignore
 import lightgbm  # type: ignore
@@ -17,7 +17,7 @@ all values sum to 1.
 """
 
 
-
+@explain_weights.register(lightgbm.Booster)
 @explain_weights.register(lightgbm.LGBMClassifier)
 @explain_weights.register(lightgbm.LGBMRegressor)
 def explain_weights_lightgbm(lgb,
@@ -32,13 +32,15 @@ def explain_weights_lightgbm(lgb,
                              ):
     """
     Return an explanation of an LightGBM estimator (via scikit-learn wrapper
-    LGBMClassifier or LGBMRegressor) as feature importances.
+    LGBMClassifier or LGBMRegressor, or via lightgbm.Booster) as feature importances.
 
     See :func:`eli5.explain_weights` for description of
     ``top``, ``feature_names``,
     ``feature_re`` and ``feature_filter`` parameters.
 
-    ``target_names`` and ``targets`` parameters are ignored.
+    ``target_names`` argument is ignored for ``lightgbm.LGBMClassifier`` / ``lightgbm.LGBMRegressor``,
+    but used for ``lightgbm.Booster``.
+    ``targets`` argument is ignored.
 
     Parameters
     ----------
@@ -51,8 +53,9 @@ def explain_weights_lightgbm(lgb,
           across all trees
         - 'weight' - the same as 'split', for compatibility with xgboost
     """
-    coef = _get_lgb_feature_importances(lgb, importance_type)
-    lgb_feature_names = lgb.booster_.feature_name()
+    booster, _ = _check_booster_args(lgb)
+    coef = _get_lgb_feature_importances(booster, importance_type)
+    lgb_feature_names = booster.feature_name()
     return get_feature_importance_explanation(lgb, vec, coef,
         feature_names=feature_names,
         estimator_feature_names=lgb_feature_names,
@@ -64,7 +67,7 @@ def explain_weights_lightgbm(lgb,
         is_regression=isinstance(lgb, lightgbm.LGBMRegressor),
     )
 
-
+@explain_prediction.register(lightgbm.Booster)
 @explain_prediction.register(lightgbm.LGBMClassifier)
 @explain_prediction.register(lightgbm.LGBMRegressor)
 def explain_prediction_lightgbm(
@@ -80,7 +83,8 @@ def explain_prediction_lightgbm(
         vectorized=False,
+        is_regression=None,
         ):
     """ Return an explanation of LightGBM prediction (via scikit-learn wrapper
-    LGBMClassifier or LGBMRegressor) as feature weights.
+    LGBMClassifier or LGBMRegressor, or via lightgbm.Booster) as feature weights.
 
     See :func:`eli5.explain_prediction` for description of
     ``top``, ``top_targets``, ``target_names``, ``targets``,
@@ -108,19 +112,48 @@ def explain_prediction_lightgbm(
     Weights of all features sum to the output score of the estimator.
     """
-    vec, feature_names = handle_vec(lgb, doc, vec, vectorized, feature_names)
+    booster, is_regression = _check_booster_args(lgb, is_regression)
+    lgb_feature_names = booster.feature_name()
+    vec, feature_names = handle_vec(lgb, doc, vec, vectorized, feature_names,
+                                    num_features=len(lgb_feature_names))
     if feature_names.bias_name is None:
         # LightGBM estimators do not have an intercept, but here we interpret
         # them as having an intercept
         feature_names.bias_name = '<BIAS>'
     X = get_X(doc, vec, vectorized=vectorized)
+
+    if isinstance(lgb, lightgbm.Booster):
+        prediction = lgb.predict(X)
+        n_targets = prediction.shape[-1]
+        if is_regression is None and target_names is None:
+            # When n_targets is 1, this can be classification too.
+            # It's safer to assume regression in this case, unless the user
+            # marks it as classification by passing target_names=[0, 1] etc.
+            # If n_targets > 1, it must be classification.
+            is_regression = n_targets == 1
+        elif is_regression is None:
+            is_regression = len(target_names) == 1 and n_targets == 1
+
+        if is_regression:
+            proba = None
+        else:
+            if n_targets == 1:
+                p, = prediction
+                proba = np.array([1 - p, p])
+            else:
+                proba, = prediction
+    else:
+        proba = predict_proba(lgb, X)
+        n_targets = _lgb_n_targets(lgb)
 
-    proba = predict_proba(lgb, X)
-    weight_dicts = _get_prediction_feature_weights(lgb, X, _lgb_n_targets(lgb))
-    x = get_X0(add_intercept(X))
+    if is_regression:
+        names = ['y']
+    elif isinstance(lgb, lightgbm.Booster):
+        names = np.arange(max(2, n_targets))
+    else:
+        names = lgb.classes_
 
-    is_regression = isinstance(lgb, lightgbm.LGBMRegressor)
-    is_multiclass = _lgb_n_targets(lgb) > 2
-    names = lgb.classes_ if not is_regression else ['y']
+    weight_dicts = _get_prediction_feature_weights(booster, X, n_targets)
+    x = get_X0(add_intercept(X))
 
     def get_score_weights(_label_id):
         _weights = _target_feature_weights(
@@ -145,22 +178,40 @@ def get_score_weights(_label_id):
         targets=targets,
         top_targets=top_targets,
         is_regression=is_regression,
-        is_multiclass=is_multiclass,
+        is_multiclass=n_targets > 1,
         proba=proba,
         get_score_weights=get_score_weights,
     )
 
-
+def _check_booster_args(lgb, is_regression=None):
+    # type: (Any, Optional[bool]) -> Tuple[lightgbm.Booster, Optional[bool]]
+    if isinstance(lgb, lightgbm.Booster):
+        booster = lgb
+    else:
+        booster = lgb.booster_
+        _is_regression = isinstance(lgb, lightgbm.LGBMRegressor)
+        if is_regression is not None and is_regression != _is_regression:
+            raise ValueError(
+                'Inconsistent is_regression={} passed. '
+                'You don\'t have to pass it when using scikit-learn API'
+                .format(is_regression))
+        is_regression = _is_regression
+    return booster, is_regression
+
+
 def _lgb_n_targets(lgb):
     if isinstance(lgb, lightgbm.LGBMClassifier):
-        return lgb.n_classes_
-    else:
+        return 1 if lgb.n_classes_ == 2 else lgb.n_classes_
+    elif isinstance(lgb, lightgbm.LGBMRegressor):
         return 1
+    else:
+        raise TypeError('expected LGBMClassifier, LGBMRegressor '
+                        'or lightgbm.Booster, got {}'.format(type(lgb)))
 
 
-def _get_lgb_feature_importances(lgb, importance_type):
+def _get_lgb_feature_importances(booster, importance_type):
     aliases = {'weight': 'split'}
-    coef = lgb.booster_.feature_importance(
+    coef = booster.feature_importance(
         importance_type=aliases.get(importance_type, importance_type)
     )
     norm = coef.sum()
@@ -237,16 +288,14 @@ def walk(tree, parent_id=-1):
 
     return leaf_index, split_index
 
 
-def _get_prediction_feature_weights(lgb, X, n_targets):
+def _get_prediction_feature_weights(booster, X, n_targets):
     """ Return a list of {feat_id: value} dicts with feature weights,
     following ideas from http://blog.datadive.net/interpreting-random-forests/
     """
-    if n_targets == 2:
-        n_targets = 1
-    dump = lgb.booster_.dump_model()
+    dump = booster.dump_model()
     tree_info = dump['tree_info']
     _compute_node_values(tree_info)
-    pred_leafs = lgb.booster_.predict(X, pred_leaf=True).reshape(-1, n_targets)
+    pred_leafs = booster.predict(X, pred_leaf=True).reshape(-1, n_targets)
     tree_info = np.array(tree_info).reshape(-1, n_targets)
     assert pred_leafs.shape == tree_info.shape
diff --git a/tests/test_lightgbm.py b/tests/test_lightgbm.py
index 46f720ee..61129dc9 100644
--- a/tests/test_lightgbm.py
+++ b/tests/test_lightgbm.py
@@ -7,9 +7,11 @@
 import numpy as np
 from sklearn.feature_extraction.text import CountVectorizer
+import lightgbm
 from lightgbm import LGBMClassifier, LGBMRegressor
 
 from eli5 import explain_weights, explain_prediction
+from eli5.lightgbm import _check_booster_args, _lgb_n_targets
 from .test_sklearn_explain_weights import (
     test_explain_tree_classifier as _check_rf_classifier,
     test_explain_random_forest_and_tree_feature_filter as
     _check_rf_feature_filter,
@@ -18,6 +20,7 @@
 )
 from .test_sklearn_explain_prediction import (
     assert_linear_regression_explained,
+    assert_trained_linear_regression_explained,
     test_explain_prediction_pandas as _check_explain_prediction_pandas,
     test_explain_clf_binary_iris as _check_binary_classifier,
 )
@@ -144,3 +147,135 @@ def test_explain_weights_feature_names_pandas(boston_train):
     res = explain_weights(reg, feature_names=numeric_feature_names)
     for expl in format_as_all(res, reg):
         assert 'zz12' in expl
+
+
+def test_check_booster_args():
+    x, y = np.random.random((10, 2)), np.random.randint(2, size=10)
+    regressor = LGBMRegressor(min_data=1).fit(x, y)
+    classifier = LGBMClassifier(min_data=1).fit(x, y)
+
+    booster, is_regression = _check_booster_args(regressor)
+    assert is_regression is True
+    assert isinstance(booster, lightgbm.Booster)
+    _, is_regression = _check_booster_args(regressor, is_regression=True)
+    assert is_regression is True
+    _, is_regression = _check_booster_args(classifier)
+    assert is_regression is False
+    _, is_regression = _check_booster_args(classifier, is_regression=False)
+    assert is_regression is False
+    with pytest.raises(ValueError):
+        _check_booster_args(classifier, is_regression=True)
+    with pytest.raises(ValueError):
+        _check_booster_args(regressor, is_regression=False)
+
+    booster = regressor.booster_
+    _booster, is_regression = _check_booster_args(booster)
+    assert _booster is booster
+    assert is_regression is None
+    _, is_regression = _check_booster_args(booster, is_regression=True)
+    assert is_regression is True
+
+    booster = classifier.booster_
+    _booster, is_regression = _check_booster_args(booster)
+    assert _booster is booster
+    assert is_regression is None
+    _, is_regression = _check_booster_args(booster, is_regression=False)
+    assert is_regression is False
+
+
+def test_explain_lightgbm_booster(boston_train):
+    xs, ys, feature_names = boston_train
+    booster = lightgbm.train(
+        params={'objective': 'regression', 'verbose': -1},
+        train_set=lightgbm.Dataset(xs, label=ys),
+    )
+    res = explain_weights(booster)
+    for expl in format_as_all(res, booster):
+        assert 'Column_12' in expl
+    res = explain_weights(booster, feature_names=feature_names)
+    for expl in format_as_all(res, booster):
+        assert 'LSTAT' in expl
+
+
+def test_explain_prediction_reg_booster(boston_train):
+    X, y, feature_names = boston_train
+    booster = lightgbm.train(
+        params={'objective': 'regression', 'verbose': -1},
+        train_set=lightgbm.Dataset(X, label=y),
+    )
+    assert_trained_linear_regression_explained(
+        X[0], feature_names, booster, explain_prediction,
+        reg_has_intercept=True)
+
+
+def test_explain_prediction_booster_multitarget(newsgroups_train):
+    docs, ys, target_names = newsgroups_train
+    vec = CountVectorizer(stop_words='english', dtype=np.float64)
+    xs = vec.fit_transform(docs)
+    clf = lightgbm.train(
+        params={'objective': 'multiclass', 'verbose': -1,
+                'max_depth': 2, 'n_estimators': 100,
+                'min_child_samples': 1, 'min_child_weight': 1,
+                'num_class': len(target_names)},
+        train_set=lightgbm.Dataset(xs.toarray(), label=ys))
+
+    doc = 'computer graphics in space: a new religion'
+    res = explain_prediction(clf, doc, vec=vec, target_names=target_names)
+    format_as_all(res, clf)
+    check_targets_scores(res)
+    graphics_weights = res.targets[1].feature_weights
+    assert 'computer' in get_all_features(graphics_weights.pos)
+    religion_weights = res.targets[3].feature_weights
+    assert 'religion' in get_all_features(religion_weights.pos)
+
+    top_target_res = explain_prediction(clf, doc, vec=vec, top_targets=2)
+    assert len(top_target_res.targets) == 2
+    assert sorted(t.proba for t in top_target_res.targets) == sorted(
+        t.proba for t in res.targets)[-2:]
+
+
+def test_explain_prediction_booster_binary(newsgroups_train_binary_big):
+    docs, ys, target_names = newsgroups_train_binary_big
+    vec = CountVectorizer(stop_words='english', dtype=np.float64)
+    xs = vec.fit_transform(docs)
+    clf = lightgbm.train(
+        params={'objective': 'binary', 'verbose': -1,
+                'max_depth': 2, 'n_estimators': 100,
+                'min_child_samples': 1, 'min_child_weight': 1},
+        train_set=lightgbm.Dataset(xs.toarray(), label=ys))
+
+    get_res = lambda **kwargs: explain_prediction(
+        clf, 'computer graphics in space: a sign of atheism',
+        vec=vec, target_names=target_names, **kwargs)
+    res = get_res()
+    for expl in format_as_all(res, clf, show_feature_values=True):
+        assert 'graphics' in expl
+    check_targets_scores(res)
+    weights = res.targets[0].feature_weights
+    pos_features = get_all_features(weights.pos)
+    neg_features = get_all_features(weights.neg)
+    assert 'graphics' in pos_features
+    assert 'computer' in pos_features
+    assert 'atheism' in neg_features
+
+    flt_res = get_res(feature_re='gra')
+    flt_pos_features = get_all_features(flt_res.targets[0].feature_weights.pos)
+    assert 'graphics' in flt_pos_features
+    assert 'computer' not in flt_pos_features
+
+
+def test_lgb_n_targets():
+    clf = LGBMClassifier(min_data=1)
+    clf.fit(np.array([[0], [1]]), np.array([0, 1]))
+    assert _lgb_n_targets(clf) == 1
+
+    clf = LGBMClassifier(min_data=1)
+    clf.fit(np.array([[0], [1], [2]]), np.array([0, 1, 2]))
+    assert _lgb_n_targets(clf) == 3
+
+    reg = LGBMRegressor(min_data=1)
+    reg.fit(np.array([[0], [1], [2]]), np.array([0, 1, 2]))
+    assert _lgb_n_targets(reg) == 1
+
+    with pytest.raises(TypeError):
+        _lgb_n_targets(object())
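For the binary-classification path exercised by the tests above, a short
companion sketch (again illustrative and not part of the patch; ``X``, ``y``
and the label strings are made up). A Booster trained with the ``'binary'``
objective predicts a single probability ``p`` per row; the patched
``explain_prediction_lightgbm`` expands it to ``proba = [1 - p, p]``, and
passing two ``target_names`` is what marks the problem as classification
when ``is_regression`` is not given::

    import lightgbm
    import numpy as np

    import eli5

    X = np.random.random((200, 4))
    y = (X[:, 0] + X[:, 1] > 1).astype(int)  # synthetic binary labels

    booster = lightgbm.train(
        params={'objective': 'binary', 'verbose': -1},
        train_set=lightgbm.Dataset(X, label=y),
    )

    # Two target_names mark this as classification, so the explanation
    # carries probabilities [1 - p, p] instead of a raw regression score.
    expl = eli5.explain_prediction(booster, X[0],
                                   target_names=['negative', 'positive'])
    print(eli5.format_as_text(expl))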