
I have built a dashboard in Streamlit where you can select a client_ID and display SHAP plots (waterfall and force plots) that interpret the credit-default prediction for that client.

I also want to display a SHAP summary plot for the whole training dataset. The latter does not change with each new prediction and takes a long time to plot, so I want to cache it. I guess the best approach would be st.cache, but I have not been able to make it work.

Below is the code I have tried, unsuccessfully, in main.py. I first define the function whose output (fig) I want to cache, then pass that output to st.pyplot. It works without the st.cache decorator, but as soon as I add the decorator and rerun the app, summary_plot_all runs indefinitely.

IN:

@st.cache
def summary_plot_all():
    fig, axes = plt.subplots(nrows=1, ncols=1)
    shap.summary_plot(shapvs[1], prep_train.iloc[:, :-1].values,
                      prep_train.columns, max_display=50)
    return fig

st.pyplot(summary_plot_all())

OUT (displayed in the Streamlit app):

Running summary_plot_all().

Does anyone know what is wrong, or a better way of caching a plot in Streamlit?

Package versions:
streamlit==0.84.1
matplotlib==3.4.2
shap==0.39.0
vpvinc

1 Answer


Try telling st.cache how to hash matplotlib Figure objects via hash_funcs:

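import matplotlib.figure  # makes matplotlib.figure.Figure available
import matplotlib.pyplot as plt
import shap
import streamlit as st

# shapvs and prep_train come from the rest of main.py.
# Hash every Figure to a constant so st.cache skips hashing the returned plot.
@st.cache(hash_funcs={matplotlib.figure.Figure: lambda _: None})
def summary_plot_all():
    fig, axes = plt.subplots(nrows=1, ncols=1)
    # show=False: don't call plt.show() inside the app; st.pyplot renders fig.
    shap.summary_plot(shapvs[1], prep_train.iloc[:, :-1].values,
                      prep_train.columns, max_display=50, show=False)
    return fig

st.pyplot(summary_plot_all())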
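Why this likely helps: the legacy st.cache hashes a function's return value (so it can warn if a cached object is later mutated), and hashing a matplotlib Figure appears to mean walking a very large object graph on every rerun. Mapping Figure to lambda _: None makes every figure hash to the same constant, so that walk is skipped. The stdlib toy model below only illustrates the mechanism under that assumption; toy_cache, toy_hash, SlowFigure and make_plot are hypothetical names, and Streamlit's real hasher is far more involved. It also models allow_output_mutation=True, a real st.cache option that skips the output hash altogether and is another commonly suggested fix.

```python
import functools
from typing import Any, Callable, Dict, Optional


def toy_hash(obj: Any, hash_funcs: Dict[type, Callable[[Any], Any]]) -> str:
    """Toy hasher: a custom hash function registered for obj's type wins."""
    for typ, func in hash_funcs.items():
        if isinstance(obj, typ):
            # Hash whatever the custom function returns instead of obj itself.
            return repr(func(obj))
    # Stand-in for walking the object's contents (slow for a real Figure).
    return repr(obj)


def toy_cache(hash_funcs: Optional[Dict[type, Callable[[Any], Any]]] = None,
              allow_output_mutation: bool = False):
    """Toy model of legacy st.cache: memoize by argument hashes and re-hash the
    cached output on every hit to detect mutation, unless that check is off."""
    funcs = hash_funcs or {}

    def decorator(func):
        store: Dict[tuple, tuple] = {}  # argument hashes -> (value, output hash)

        @functools.wraps(func)
        def wrapper(*args):
            key = tuple(toy_hash(a, funcs) for a in args)
            if key in store:
                value, saved = store[key]
                # Mutation check: re-hash the cached output on every cache hit.
                if not allow_output_mutation and toy_hash(value, funcs) != saved:
                    print("warning: cached return value was mutated")
                return value
            value = func(*args)
            saved = None if allow_output_mutation else toy_hash(value, funcs)
            store[key] = (value, saved)
            return value

        return wrapper

    return decorator


class SlowFigure:
    """Hypothetical stand-in for a matplotlib Figure; counts how often it is hashed."""
    hash_count = 0

    def __repr__(self) -> str:
        SlowFigure.hash_count += 1
        return "SlowFigure()"


def make_plot() -> SlowFigure:
    return SlowFigure()  # imagine the expensive shap.summary_plot call here


naive = toy_cache()(make_plot)
fixed = toy_cache(hash_funcs={SlowFigure: lambda _: None})(make_plot)

for name, cached in [("naive", naive), ("fixed", fixed)]:
    SlowFigure.hash_count = 0
    first, second = cached(), cached()
    print(name, "same object:", first is second, "| figure hashes:", SlowFigure.hash_count)
```

Running it prints `naive same object: True | figure hashes: 2` and then `fixed same object: True | figure hashes: 0`: both versions return the cached object, but only the naive one keeps hashing the figure on each call.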

For more context, check this Streamlit GitHub issue.

Ailurophile