
I am trying to get the features which are important for a class and have a positive contribution (having red points on the positive side of the SHAP plot).

I can get the shap_values and plot the shap summary for each class (e.g. class 2 here) using the following code:

import shap 
explainer = shap.TreeExplainer(clf) 
shap_values = explainer.shap_values(X) 
shap.summary_plot(shap_values[2], X) 

From the plot I can see which features are important for that class. In the plot below, alcohol and sulphates are the main features (the ones I am most interested in).

[Image: SHAP summary plot]

However, I want to automate this process, so the code can rank the features (which are important on the positive side) and return the top N. Any idea on how to automate this interpretation?

I need to identify those important features automatically for each class. Any method other than SHAP that can handle this would also be welcome.

Sali

1 Answer


You can do the following. The idea is to keep only the SHAP values that push the prediction towards the class (shap_values > 0); a value below 0 pushes the prediction away from the class. Then take the mean per feature and sort the results. If you prefer the global importance, use .abs() instead of [shap_df>0], and for the whole model use shap_values itself instead of the slice for a single class.

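import shap
import pandas as pd

class_idx = 2  # index of the class of interest (class 2 in the question)
n = 5          # number of top features to return (example value)

explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(X)
# One row per sample, one column per feature, for the chosen class
shap_df = pd.DataFrame(shap_values[class_idx], columns=X.columns)

feature_importance = (
    shap_df[shap_df > 0]               # keep positive contributions; the rest become NaN
    .mean()                            # mean positive SHAP value per feature (NaN skipped)
    .sort_values(ascending=False)
    .reset_index()
    .rename(columns={'index': 'feature', 0: 'weight'})
    .head(n)
)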
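To run the same ranking for every class at once, the steps above could be wrapped in a small helper. This is only a sketch under stated assumptions: `top_positive_features` is a hypothetical name, not part of SHAP, and it assumes the SHAP output is either a list with one (n_samples, n_features) array per class or a single array of shape (n_samples, n_features, n_classes), which of the two you get may depend on the SHAP version. The demo uses made-up numbers in place of real SHAP values so it runs without a model.

```python
import numpy as np
import pandas as pd


def top_positive_features(shap_values, feature_names, n=5):
    """Hypothetical helper: rank features per class by mean positive SHAP value.

    Assumes shap_values is a list of (n_samples, n_features) arrays (one per
    class) or one array of shape (n_samples, n_features, n_classes).
    Returns a dict mapping class index -> pandas Series (feature -> weight).
    """
    if isinstance(shap_values, list):
        per_class = [np.asarray(v) for v in shap_values]
    else:
        arr = np.asarray(shap_values)
        if arr.ndim != 3:
            raise ValueError("expected a list of 2-D arrays or one 3-D array")
        per_class = [arr[:, :, k] for k in range(arr.shape[2])]

    result = {}
    for k, values in enumerate(per_class):
        df = pd.DataFrame(values, columns=feature_names)
        weights = (
            df.where(df > 0)   # non-positive entries become NaN
            .mean()            # NaN is skipped, so this averages positives only
            .dropna()          # drop features that never contribute positively
            .sort_values(ascending=False)
        )
        result[k] = weights.head(n)
    return result


# Made-up stand-in for SHAP output: 2 classes, 4 samples, 3 features.
demo_features = ["alcohol", "sulphates", "pH"]
class0 = np.array([
    [0.2, -0.1, 0.0],
    [0.4, -0.3, 0.1],
    [-0.1, -0.2, 0.3],
    [0.3, -0.4, -0.2],
])
demo_shap = [class0, -class0]

top = top_positive_features(demo_shap, demo_features, n=2)
for class_idx, weights in top.items():
    print(class_idx, list(weights.index))
```

With a real model you would pass `explainer.shap_values(X)` and `X.columns` instead of the demo data. Note that the mean of the positive values ignores how often a feature contributes positively, so a feature with a single large positive value can rank high; checking the count of positive values alongside the mean may help.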
sarielg