Question (score 16)

I'm currently working on a classification problem and want to create visualizations of feature importance. I use the Python XGBoost package, which already provides feature importance plots. However, I found shap (https://github.com/slundberg/shap), a Python library that creates very nice feature importance plots for tree-based classifiers. Everything works fine and I can save the created plots as PNG; however, if I try to save them as PDF or SVG, I get an exception. Here is what I am doing:

First, I train the XGBoost model, which is returned as bst:

train = remove_labels_for_binary_df(dataset_fc_baseline_1[0].train)
test = remove_labels_for_binary_df(dataset_fc_baseline_1[0].test)
results, bst = xgboost_with_bst(*transform_feat_to_num(train, test))

Then I compute the shap values, use them to create a summary plot, and save the created visualization. Everything works fine if I save the plot with plt.savefig('shap.png').

import shap
import matplotlib.pyplot as plt

shap.initjs()

explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(train)
fig = shap.summary_plot(shap_values, train, show=False)
plt.savefig('shap.png')

However, I need PDF or SVG plots instead of PNG, so I tried to save the plot with plt.savefig('shap.pdf'). This normally works fine, but for the shap plot it produces the following exception:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-39-49d17973f438> in <module>()
  1 fig = shap.summary_plot(shap_values, train, show=False)
----> 2 plt.savefig('shap.pdf')

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\pyplot.py in savefig(*args, **kwargs)
708 def savefig(*args, **kwargs):
709     fig = gcf()
--> 710     res = fig.savefig(*args, **kwargs)
711     fig.canvas.draw_idle()   # need this if 'transparent=True' to reset colors
712     return res

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\figure.py in savefig(self, fname, **kwargs)
2033             self.set_frameon(frameon)
2034 
-> 2035         self.canvas.print_figure(fname, **kwargs)
2036 
2037         if frameon:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\backend_bases.py in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, **kwargs)
2261                 orientation=orientation,
2262                 bbox_inches_restore=_bbox_inches_restore,
-> 2263                 **kwargs)
2264         finally:
2265             if bbox_inches and restore_bbox:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\backends\backend_pdf.py in print_pdf(self, filename, **kwargs)
2584                 RendererPdf(file, image_dpi, height, width),
2585                 bbox_inches_restore=_bbox_inches_restore)
-> 2586             self.figure.draw(renderer)
2587             renderer.finalize()
2588             if not isinstance(filename, PdfPages):

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
 53                 renderer.start_filter()
 54 
---> 55             return draw(artist, renderer, *args, **kwargs)
 56         finally:
 57             if artist.get_agg_filter() is not None:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\figure.py in draw(self, renderer)
1473 
1474             mimage._draw_list_compositing_images(
-> 1475                 renderer, self, artists, self.suppressComposite)
1476 
1477             renderer.close_group('figure')

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\image.py in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
139     if not_composite or not has_images:
140         for a in artists:
--> 141             a.draw(renderer)
142     else:
143         # Composite any adjacent images together

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
 53                 renderer.start_filter()
 54 
---> 55             return draw(artist, renderer, *args, **kwargs)
 56         finally:
 57             if artist.get_agg_filter() is not None:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\axes\_base.py in draw(self, renderer, inframe)
2605             renderer.stop_rasterizing()
2606 
-> 2607         mimage._draw_list_compositing_images(renderer, self, artists)
2608 
2609         renderer.close_group('axes')

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\image.py in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
139     if not_composite or not has_images:
140         for a in artists:
--> 141             a.draw(renderer)
142     else:
143         # Composite any adjacent images together

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
 58                 renderer.stop_filter(artist.get_agg_filter())
 59             if artist.get_rasterized():
---> 60                 renderer.stop_rasterizing()
 61 
 62     draw_wrapper._supports_rasterization = True

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\backends\backend_mixed.py in stop_rasterizing(self)
128 
129             height = self._height * self.dpi
--> 130             buffer, bounds = self._raster_renderer.tostring_rgba_minimized()
131             l, b, w, h = bounds
132             if w > 0 and h > 0:

C:\Users\Studio\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py in tostring_rgba_minimized(self)
138                 [extents[0] + extents[2], self.height - extents[1]]]
139         region = self.copy_from_bbox(bbox)
--> 140         return np.array(region), extents
141 
142     def draw_path(self, gc, path, transform, rgbFace=None):

ValueError: negative dimensions are not allowed

Do you have any idea how to fix this?

tdy
Roqua

7 Answers

Answer (score 7)

When saving a force plot, pass matplotlib=True and show=False to shap.force_plot, so that the plot is drawn with matplotlib and is not displayed before you save it:

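import shap
import matplotlib.pyplot as plt

def heart_disease_risk_factors(model, patient):
    # Explain a single prediction of a tree-based binary classifier
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(patient)
    shap.initjs()

    # matplotlib=True renders with matplotlib instead of JavaScript;
    # show=False keeps the figure open so it can be saved afterwards
    return shap.force_plot(explainer.expected_value[1], shap_values[1],
                           patient, matplotlib=True, show=False)


plt.clf()
data_for_prediction = X_test.iloc[2, :].astype(float)
heart_disease_risk_factors(model, data_for_prediction)
plt.savefig("gg.png", dpi=150, bbox_inches='tight')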
jtlz2
Answer (score 7)

By default, summary_plot calls plt.show() to make sure the plot is displayed. If you pass show=False to summary_plot, it does not, so the figure stays open and you can save it, e.g.:

# SHAP summary plot
import matplotlib.pyplot as pl
shap.summary_plot(shap_values, X_train, max_display=10, show=False)
pl.savefig("shap_summary.svg", dpi=700)  # .png and .pdf also work here
pl.show()
Snehal Rajput
  • This suggestion also works for `shap.waterfall_plot` for explaining model prediction for any specific sample. – Heelara Oct 18 '22 at 05:46
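The show=False behaviour described in this answer can be illustrated without shap. A minimal sketch, assuming a hypothetical summary_like_plot function that only stands in for shap.summary_plot (it is not shap's code) and a hypothetical save_current_figure helper: because the figure is not shown, plt.gcf() still returns it, and it can be written as PNG, PDF and SVG.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def summary_like_plot(values, show=True):
    """Hypothetical stand-in for a plotting helper with a `show` flag."""
    plt.figure()
    plt.scatter(values[:, 0], values[:, 1], s=5)
    plt.xlabel("SHAP value (illustrative data)")
    if show:
        # On interactive backends the figure may be gone once the window closes
        plt.show()


def save_current_figure(formats=("png", "pdf", "svg")):
    """Save the current figure into in-memory buffers, one per format."""
    fig = plt.gcf()
    outputs = {}
    for fmt in formats:
        buf = io.BytesIO()
        # format= is required because a buffer has no file extension
        fig.savefig(buf, format=fmt, bbox_inches="tight")
        outputs[fmt] = buf.getvalue()
    return outputs


rng = np.random.default_rng(0)
summary_like_plot(rng.normal(size=(200, 2)), show=False)
files = save_current_figure()
plt.close("all")
```

To write files instead of buffers, plt.savefig("shap_summary.pdf") is enough, since matplotlib infers the format from the file extension.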
Answer (score 4)

This is an incompatibility between NumPy and matplotlib that is triggered when plotting with rasterized=True (which shap does when there are more than 500 data points). It has been fixed in newer matplotlib releases, so upgrading matplotlib should make saving as PDF or SVG work again.
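If upgrading matplotlib is not an option, one possible workaround (my assumption, not something this answer or shap documents) is to switch off rasterization on the figure's artists before saving as PDF or SVG, so that the mixed-mode stop_rasterizing path seen in the traceback is never entered. A minimal sketch using only matplotlib: disable_rasterization is a hypothetical helper name, and the 1000-point scatter with rasterized=True only imitates the kind of dense plot summary_plot draws. The trade-off is a larger vector file, because every point is then written individually.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np


def disable_rasterization(fig):
    """Hypothetical helper: turn off rasterization for every artist in fig."""
    changed = 0
    # findobj() walks the whole artist tree, including the figure itself
    for artist in fig.findobj():
        if artist.get_rasterized():
            artist.set_rasterized(False)
            changed += 1
    return changed


rng = np.random.default_rng(0)
fig, ax = plt.subplots()
# Imitate a dense scatter drawn with rasterized=True
ax.scatter(rng.normal(size=1000), rng.normal(size=1000), s=4, rasterized=True)

n_changed = disable_rasterization(fig)
buf = io.BytesIO()
fig.savefig(buf, format="pdf")  # pure vector output, no rasterized regions
pdf_bytes = buf.getvalue()
plt.close(fig)
```

With shap, the idea would be to call disable_rasterization(plt.gcf()) after shap.summary_plot(..., show=False) and before plt.savefig('shap.pdf'); this has not been checked against a specific shap release.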

Answer (score 2)

I think the easiest way is:

shap.summary_plot(shap_values, X, show=False)
plt.savefig('mygraph.pdf', format='pdf', dpi=600, bbox_inches='tight')
plt.show()
  • Your answer could be improved with additional supporting information. Please [edit] to add further details, such as citations or documentation, so that others can confirm that your answer is correct. You can find more information on how to write good answers in the [help center](https://stackoverflow.com/help/how-to-answer). – Ethan Jun 17 '22 at 18:53
Answer (score 2)

The easiest way is to save as follows:

fig = shap.summary_plot(shap_values, X_test, plot_type="bar", feature_names=["a", "b"], show=False)
plt.savefig("trial.png")

Note: by default, summary_plot calls plt.show() to ensure the plot is displayed. If you pass show=False to summary_plot, it doesn't, so the figure is still open when you call plt.savefig.

https://github.com/slundberg/shap/issues/153

Amit Tiwari
Answer (score -1)

Please try this:

shap.plots.force(shap_values[0], show=False, matplotlib=True).savefig('shap.pdf')
ah bon
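Chaining .savefig(...) directly onto the plotting call, as in this answer, only works if the call returns a matplotlib Figure. As far as I know, shap's force plot does so when called with matplotlib=True and show=False, but a call that returns None would raise AttributeError, and plt.gcf() is then the fallback. A minimal sketch of that distinction in plain matplotlib: force_like_plot and save_figure are hypothetical helpers, not shap functions, and the bar values are made up.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
from matplotlib.figure import Figure


def force_like_plot(values, show=True):
    """Hypothetical stand-in: returns the Figure only when show=False."""
    fig, ax = plt.subplots()
    ax.barh(range(len(values)), values)
    if show:
        plt.show()
        return None
    return fig


def save_figure(result, target, fmt):
    """Save `result` if it is a Figure; otherwise fall back to the current figure."""
    fig = result if isinstance(result, Figure) else plt.gcf()
    fig.savefig(target, format=fmt, bbox_inches="tight")
    return fig


returned = force_like_plot([0.3, -0.1, 0.2], show=False)

# Chaining works here because a Figure was returned
pdf_buf = io.BytesIO()
returned.savefig(pdf_buf, format="pdf")

# With a None result, chaining would fail; save_figure falls back to gcf()
svg_buf = io.BytesIO()
fallback = save_figure(None, svg_buf, "svg")
plt.close("all")
```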
Answer (score -2)

Saving as PDF:

plt.savefig("shap.pdf", format='pdf', dpi=1000, bbox_inches='tight')

Saving as EPS:

plt.savefig("shap.eps", format='eps', dpi=1000, bbox_inches='tight')

For more information, for example what bbox_inches='tight' means, see the matplotlib documentation for matplotlib.pyplot.savefig.

ah bon
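A small illustration of the two savefig parameters used in the last answer: when saving to a filename, matplotlib infers the format from the extension, so format= is only strictly needed when the target has no extension (for example an in-memory buffer), and bbox_inches='tight' crops the output to the area actually used by the figure's artists. A minimal sketch in plain matplotlib; the plotted data and the png_size helper are made up for illustration.

```python
import io
import struct

import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt


def png_size(data):
    """Read (width, height) in pixels from the IHDR chunk of a PNG."""
    return struct.unpack(">II", data[16:24])


fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])

# File formats this canvas can write; "pdf", "svg" and "eps" are among them
supported = fig.canvas.get_supported_filetypes()

full, tight = io.BytesIO(), io.BytesIO()
# A buffer has no extension, so the format is given explicitly
fig.savefig(full, format="png", dpi=100)
fig.savefig(tight, format="png", dpi=100, bbox_inches="tight")
plt.close(fig)

full_size = png_size(full.getvalue())
tight_size = png_size(tight.getvalue())
```

The tightly cropped image comes out smaller in both dimensions, because the default figure margins around the axes are trimmed down to the padding set by pad_inches.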