python, streamlit, shap

Cannot display SHAP text visualization in Streamlit


I am trying to build a dashboard for my NLP project. I am using a BERT model for the predictions, the SHAP package for the visualization, and Streamlit for the dashboard:

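import numpy as np
import scipy as sp
import scipy.special  # ensures sp.special is loaded on older SciPy versions
import shap
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# model_name_cla: name or path of the fine-tuned classification checkpoint (not shown in the question)
tokenizer = AutoTokenizer.from_pretrained(model_name_cla)
model = AutoModelForSequenceClassification.from_pretrained(model_name_cla)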
labels = ['1- Tarife','2- Dateneingabe','3- Bestätigungsmail','4- Kundenbetreuung','5- Aufwand vom Vergleich bis Abschluss',
          '6- After-sales Wechselprozess','7 - Werbung/VX Kommunikation','8 - Sonstiges','9 - Nicht auswertbar']

def f(x):
    tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=128, truncation=True) for v in x])
    attention_mask = (tv!=0).type(torch.int64)
    outputs = model(tv,attention_mask=attention_mask)[0].detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = sp.special.logit(scores)
    return val

text = ['This is just a test']

# build an explainer using a token masker
explainer = shap.Explainer(f, tokenizer, output_names=labels)

shap_values = explainer(text, fixed_context=1)

shap.plots.text(shap_values)

The code works fine in my Jupyter notebook, but when I run it as a .py file with Streamlit, nothing happens: it neither displays anything nor throws an error. My console just prints this during execution:

<IPython.core.display.HTML object>

How can I display my graph in Streamlit?


Solution

  • This can be visualized with Streamlit Components and SHAP v0.36+ (which defines a new getjs method) to render the JS-based SHAP plots.

    (Some plots, such as summary_plot, are actually Matplotlib figures and can be shown with st.pyplot.)

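    import shap
    import streamlit as st
    import streamlit.components.v1 as components
    
    def st_shap(plot, height=None):
        # force_plot-style results expose .html(); in recent SHAP releases
        # shap.plots.text(..., display=False) returns the HTML as a plain string.
        plot_html = plot if isinstance(plot, str) else plot.html()
        shap_html = f"<head>{shap.getjs()}</head><body>{plot_html}</body>"
        components.html(shap_html, height=height)
    
    # display=False asks shap.plots.text to return the HTML instead of rendering it inline
    st_shap(shap.plots.text(shap_values, display=False), 400)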
    

    For a more detailed discussion of visualizing SHAP in Streamlit, see "Display SHAP diagrams with Streamlit".
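The wrapper idea above depends on what the plotting call actually hands back, and that varies by plot type. The following is a minimal sketch, not the answer author's code. The helper names `plot_to_html` and `wrap_html` are hypothetical. It assumes a plot result is either an HTML string (as `shap.plots.text(..., display=False)` appears to return in recent SHAP releases) or an object with an `.html()` method (as `shap.force_plot` results have). A `None` result, which is likely what an inline-rendering call returns, raises a descriptive error instead of failing silently. The Streamlit and SHAP imports are deferred so the pure helpers can be checked without either package installed.

```python
from typing import Any, Optional


def plot_to_html(plot: Any) -> str:
    """Return the HTML markup for a SHAP plot result.

    Assumed input shapes (illustrative, not exhaustive):
      * a plain HTML string
      * an object exposing an .html() method
    """
    if plot is None:
        # The plotting call most likely rendered inline and returned nothing.
        raise ValueError(
            "plot is None; ask the plotting function to return HTML "
            "(for example display=False on shap.plots.text)"
        )
    if isinstance(plot, str):
        return plot
    if hasattr(plot, "html"):
        return plot.html()
    raise TypeError(f"cannot convert {type(plot).__name__} to HTML")


def wrap_html(body: str, head: str = "") -> str:
    """Wrap plot markup in a minimal head/body document."""
    return f"<head>{head}</head><body>{body}</body>"


def st_shap(plot: Any, height: Optional[int] = None) -> None:
    """Render a SHAP plot inside a Streamlit component."""
    # Deferred imports: only needed when actually rendering in Streamlit.
    import shap
    import streamlit.components.v1 as components

    document = wrap_html(plot_to_html(plot), head=shap.getjs())
    components.html(document, height=height)
```

In a Streamlit script this would be called as `st_shap(shap.plots.text(shap_values, display=False), height=400)`. Passing the result of a call that rendered inline would surface as a `ValueError` rather than an empty page.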