Hovering Annotations on Seaborn Stripplot


I am trying to build a seaborn stripplot that shows a point's column and index in the DataFrame when the mouse hovers over it. This raises a few questions:

  • What does stripplot.contains() return?

I get that it returns a boolean saying whether the event lies in the container-artist and a dict giving the labels of the picked data points. But what does this dict actually look like in the case of a 2D DataFrame?

  • How can I locate a data point in my DataFrame using this data?

Thank you for your help.

My current program is as follows; it is largely taken from this issue:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

#Creating the data frame
A = pd.DataFrame(data = np.random.randint(10,size=(5,5)))

fig,ax = plt.subplots()

strp = sns.stripplot(A)

#Creating an empty annotation
annot = ax.annotate("", xy=(0,0), xytext=(20,20),textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"))
annot.set_visible(False)

#Updating the annotation based on what stripplot.contains() returns
def update_annot(ind):
    mX, mY = A[ind["ind"][0]], A[ind["ind"][0]].loc[ [ind["ind"][0]] ]
    annot.xy = (mX,mY)

    annot.set_text(str(ind["ind"][0])+", "+str(ind["ind"][0]))
    annot.get_bbox_patch().set_facecolor("blue")
    annot.get_bbox_patch().set_alpha(0.4)
    
#Linking the annotation update to the event
def hover(event):
    vis = annot.get_visible()
    #Create the proper annotation if the event occurs within the bounds of the axis
    if event.inaxes == ax: 
        cont, ind = strp.contains(event)
        if cont:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        else:
            if vis:
                annot.set_visible(False)
                fig.canvas.draw_idle()

#Call the hover function when the mouse moves
fig.canvas.mpl_connect("motion_notify_event", hover)

plt.show()


My guess is that the dict returned by stripplot.contains does not have the shape I expect. And since I cannot find a way to print it (once the last line has run, nothing prints anymore), it is hard for me to know...

Thank you!


Solution

  • Unlike similar matplotlib functions, sns.stripplot(...) returns the ax on which the plot has been created. As such, strp.contains(...) is the same as ax.contains(...).
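
To make the difference concrete, here is a minimal sketch using plain matplotlib only. Two `ax.scatter` calls stand in for the per-column PathCollections that stripplot creates, and a hand-built `MouseEvent` stands in for a real mouse move; the helper name `find_dot` and the sample coordinates are made up for illustration. It suggests that `ax.contains(event)` gives back a dict without an `"ind"` key, whereas calling `contains` on an individual collection gives `{"ind": array_of_dot_numbers}`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseEvent


def find_dot(ax, event):
    """Hypothetical helper: return (collection_number, dot_number) for the
    first dot under the event, or None if no dot is hit."""
    for number, collection in enumerate(ax.collections):
        hit, details = collection.contains(event)
        if hit:
            return number, int(details["ind"][0])
    return None


fig, ax = plt.subplots()
ax.scatter([0, 0, 0], [1, 5, 9], s=16)  # stands in for the dots of column 0
ax.scatter([1, 1, 1], [2, 6, 8], s=16)  # stands in for the dots of column 1
fig.canvas.draw()  # make sure axis limits and transforms are final

# Simulate the mouse sitting on the dot at data coordinates (1, 6).
px, py = ax.transData.transform((1, 6))
event = MouseEvent("motion_notify_event", fig.canvas, px, py)

ax_hit, ax_details = ax.contains(event)                    # asks the Axes itself
dot_hit, dot_details = ax.collections[1].contains(event)   # asks one collection
found = find_dot(ax, event)

print(ax_hit, ax_details)
print(dot_hit, dot_details)
print(found)
```

If this holds for your matplotlib version, the `ind["ind"]` lookup in the question fails because the dict comes from the Axes rather than from a collection. In a `hover` callback, looping over `ax.collections` as in `find_dot` would give a collection number (the column, assuming one collection per x-value) and a dot number within that collection.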

    To add annotations to plots, mplcursors is a handy library. The sel parameter in its annotation function has the following fields:

    • sel.target: the x and y position of the element under the cursor
    • sel.artist: the matplotlib element under the cursor; in the case of stripplot, this is a collection of dots grouped in a PathCollection. There is one such collection per x-value.
    • sel.index: the index into the selected artist.

    The example code below is tested with seaborn 0.13.2 and matplotlib 3.8.3, starting from Seaborn's tips dataset.

    collection_to_day is a dictionary that maps each collection to its corresponding x-value. Adapting it to your specific situation might need some tweaks if the x-values aren't of the pd.Categorical type.

    renumbering is a dictionary that contains an array for each "day". That array maps the index of the dots to the index in the original dataframe. The original dataframe should not contain NaN values for the x or y values of the plot, as those will be filtered out.
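
The renumbering idea can be tried out with pandas alone. The sketch below uses a small made-up DataFrame (not the tips dataset) with a deliberately non-default index, and assumes, as the approach above does, that the dots within one collection appear in the same order as the matching rows of the DataFrame.

```python
import pandas as pd

# Made-up data; the index labels 10..14 are chosen so that positions and
# labels cannot be confused.
df = pd.DataFrame(
    {
        "day": pd.Categorical(
            ["Thur", "Fri", "Thur", "Sat", "Fri"],
            categories=["Thur", "Fri", "Sat"],
        ),
        "tip": [1.0, 2.5, 3.0, 4.2, 1.8],
    },
    index=[10, 11, 12, 13, 14],
)

# For each day: an array whose k-th entry is the original index label of the
# k-th row belonging to that day.
renumbering = {
    day: df[df["day"] == day].reset_index()["index"].to_numpy()
    for day in df["day"].cat.categories
}

# Suppose the cursor reports the second dot (position 1) of the "Fri" group.
index_in_df = renumbering["Fri"][1]
row = df.loc[index_in_df]  # label-based lookup

print(renumbering)
print(index_in_df, row["tip"])
```

Because the arrays hold index labels, `.loc` is used here. With a default 0..n-1 index, as in the tips dataset, labels and positions coincide, so `.iloc` gives the same row.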

    import matplotlib.pyplot as plt
    import seaborn as sns
    import mplcursors
    
    def show_annotation(sel):
        day = collection_to_day[sel.artist]
        index_in_df = renumbering[day][sel.index]
        row = tips.iloc[index_in_df]
        txt = f"Day: {row['day']}\nTime: {row['time']}\nTip: {row['tip']}\nBill total: {row['total_bill']}"
        sel.annotation.set_text(txt)
    
    tips = sns.load_dataset('tips')
    
    fig, ax = plt.subplots()
    sns.stripplot(data=tips, x='day', y='tip', hue='time', palette='turbo', ax=ax)
    
    days = tips['day'].cat.categories
    collection_to_day = dict(zip(ax.collections, days))
    renumbering = dict()
    for day in days:
        renumbering[day] = tips[tips['day'] == day].reset_index()['index']
    
    cursor = mplcursors.cursor(hover=True)
    cursor.connect('add', show_annotation)
    plt.show()
    

    [Figure: sns.stripplot with annotations]