Skip to content

Commit

Permalink
fix erroranalysis test failures due to new shap release having incons…
Browse files Browse the repository at this point in the history
…istent dimensions for single valued target
  • Loading branch information
imatiach-msft committed Apr 5, 2024
1 parent 4bb3835 commit bf4420a
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions erroranalysis/erroranalysis/error_correlation_methods/gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@

def compute_gbm_global_importance(input_data, diff, model_task,
categorical_indexes):
"""Compute global importance score for EBM between the features and error.
:param input_data: The input data to compute the EBM global importance
"""Compute global importance score for GBM between the features and error.
:param input_data: The input data to compute the GBM global importance
score on.
:type input_data: numpy.ndarray
:param diff: The difference between the label and prediction
columns.
:type diff: numpy.ndarray
:param model_task: The model task.
:type model_task: str
:return: The computed EBM global importance score between the features and
:return: The computed GBM global importance score between the features and
error.
:rtype: list[float]
"""
Expand All @@ -33,6 +33,11 @@ def compute_gbm_global_importance(input_data, diff, model_task,
model.fit(input_data, diff, categorical_feature=categorical_indexes)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(input_data)
dims = np.shape(shap_values)
# fix some inconsistencies in the shape of the shap_values
# for newer versions of shap>=0.45.0 for single-valued target column
if is_classification and len(dims) == 2:
shap_values = np.expand_dims(shap_values, axis=0)
shap_mean_abs = np.abs(shap_values).mean(axis=0)
if is_classification:
shap_mean_abs = shap_mean_abs.mean(axis=0)
Expand Down

0 comments on commit bf4420a

Please sign in to comment.