
I computed SHAP values for my neural network and want to plot them as a bar chart that shows only the 10 most important features as individual bars and sums up the importance of all remaining features in one additional bar.

As far as I understand, this should be possible with shap.plots.bar().

However, whenever I try to run the code, I get the following error:

AssertionError: You must pass an Explanation object, Cohorts object, or dictionary to bar plot!

Next, I tried shap.summary_plot(..., plot_type="bar"), since that is another way of displaying SHAP values as a bar chart. This does work, but it does not sum up the remaining features in one bar.

So my question is: what am I doing wrong with shap.plots.bar(), or how can I get shap.summary_plot(..., plot_type="bar") to sum up the remaining features in one bar?
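A minimal sketch, assuming NumPy and a single output's SHAP array of shape (n_samples, n_features), of the aggregation being asked about: rank features by mean absolute SHAP value, keep the top ones, and collapse the rest into one bar. As far as I can tell this mirrors what shap.plots.bar does with max_display (it keeps max_display - 1 individual bars and labels the last one "Sum of N other features"). The function name top_k_with_rest is hypothetical, not part of SHAP.

```python
import numpy as np


def top_k_with_rest(shap_values, feature_names, max_display=10):
    """Return (label, height) pairs: top features plus one summed 'rest' bar.

    shap_values: array of shape (n_samples, n_features) for a single output.
    Heights are mean absolute SHAP values per feature.
    """
    importance = np.abs(np.asarray(shap_values)).mean(axis=0)
    order = np.argsort(importance)[::-1]  # most important feature first
    if len(order) <= max_display:
        return [(feature_names[i], float(importance[i])) for i in order]
    # Keep max_display - 1 individual bars and reserve the last one for the rest.
    top, rest = order[: max_display - 1], order[max_display - 1:]
    bars = [(feature_names[i], float(importance[i])) for i in top]
    bars.append((f"Sum of {len(rest)} other features", float(importance[rest].sum())))
    return bars
```

With 160 features and max_display=10 this yields nine named bars plus "Sum of 151 other features"; the pairs can then be drawn with any bar-chart function, e.g. matplotlib's barh.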

Here is my code:

explainer = shap.KernelExplainer(model=agent.policy.predict, data=state_df, link="identity")
shap_values = explainer.shap_values(X=state_df.iloc[0:35, :])

shap.summary_plot(shap_values=shap_values[0], features=state_df.iloc[0:35, :], plot_type="bar")
shap.plots.bar(shap_values[0], max_display=10)

Note that my background data set has 35 samples and that the model has 160 inputs and 8 outputs, so my inputs state_df have shape (35, 160) and my outputs action_df have shape (35, 8). In the code above I am trying to display the SHAP values for the first output, which is why I use shap_values[0].

Hope someone can help :)

jakoebly

2 Answers


I had the same issue. Using shap.summary_plot(..., plot_type="bar") works for me:

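import numpy as np
import shap
import torch

# model, x_train, sample_indices, x_cols and DEVICE come from the surrounding training code
x = np.array(x_train[sample_indices], dtype=np.float32)
x_tensor = torch.from_numpy(x).to(DEVICE)
e = shap.DeepExplainer(model, x_tensor)

shap_values = e.shap_values(x_tensor)

# features must be the same samples that were explained
shap.summary_plot(
    shap_values, features=x, feature_names=x_cols, plot_type="bar", max_display=30)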


Avi Avidan

I had the same issue. You can resolve it by calling the explainer object directly, which returns an Explanation object instead of a raw array:

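X = state_df.iloc[0:35, :]
shap_values = explainer(X)  # a shap.Explanation with shape (35, 160, 8)
shap.plots.bar(shap_values[:, :, 0], max_display=10)  # first output only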