Skip to content

Commit

Permalink
🖍️ improving docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasvdd committed May 2, 2024
1 parent c966a88 commit f694cd6
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions powershap/shap_wrappers/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,15 @@ def explain(
Shap_values = np.abs(Shap_values)

if len(np.shape(Shap_values)) > 2:
# SHAPE: (n_samples, n_features, n_outputs)
assert len(np.shape(Shap_values)) == 3, "Shap values should be 3D"
# in case of multi-output, we take the max of the outputs as the shap value
Shap_values = np.max(Shap_values, axis=-1)
# new shape = (n_samples, n_features)

# TODO: consider to convert to even float16?
Shap_values = np.mean(Shap_values, axis=0).astype("float32")
# new shape = (n_features,)

shaps += [Shap_values]

Expand Down

0 comments on commit f694cd6

Please sign in to comment.