Search code examples
Tags: python, plotly, data-visualization, word2vec, plotly-python

Plotly - Highlight data point and nearest three points on hover


I have made a scatter plot of my word2vec model's vectors using plotly.
I would like to highlight a specific data point when I hover over it, together with its three nearest vectors. It would be a great help if anyone could guide me with this or suggest another option.

model
csv

Code:

import gensim
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
import plotly.express as px

def get_2d_coordinates(model, words, *, return_labels=False):
    """Project the word2vec vectors of ``words`` to 2-D with t-SNE.

    Words for which ``model.wv.get_vector`` raises ``KeyError`` (not in the
    vocabulary) are skipped. The returned arrays can therefore be shorter
    than ``words``. Pass ``return_labels=True`` to learn which words were kept.

    Args:
        model: A trained gensim ``Word2Vec`` model (anything exposing
            ``model.wv.get_vector``).
        words: Iterable of tokens to look up.
        return_labels: If true, also return the list of embedded words, in
            the same order as the coordinates.

    Returns:
        ``(x_coords, y_coords)``, or ``(x_coords, y_coords, labels)`` when
        ``return_labels`` is true. If no word has a vector, the coordinate
        arrays are empty.
    """
    vectors = []
    labels = []
    for word in words:
        try:
            vectors.append(model.wv.get_vector(word))
        except KeyError:
            # Out-of-vocabulary entry: skip it (best-effort, like the original),
            # but let any other error surface instead of hiding it.
            continue
        labels.append(word)

    if not vectors:
        # t-SNE cannot be fitted on zero samples; return empty coordinates.
        empty = np.empty(0, dtype='f')
        if return_labels:
            return empty, empty.copy(), labels
        return empty, empty.copy()

    # Stack once instead of np.append per word (which re-copies the whole
    # array each time). The vector width comes from the model itself rather
    # than a hard-coded 100, so models of any vector size work.
    arr = np.asarray(vectors, dtype='f')
    tsne = TSNE(n_components=2, random_state=0)
    np.set_printoptions(suppress=True)  # kept from the original: affects NumPy printing only
    Y = tsne.fit_transform(arr)
    x_coords = Y[:, 0]
    y_coords = Y[:, 1]
    if return_labels:
        return x_coords, y_coords, labels
    return x_coords, y_coords

ic_model = gensim.models.Word2Vec.load("w2v_IceCream.model")
ic = pd.read_csv('ic_prods.csv')

# get_2d_coordinates() skips words that have no vector. That would leave the
# coordinate arrays shorter than the CSV columns and make pd.DataFrame() below
# fail. Keep only rows whose ITEM_DESC is in the vocabulary, and renumber
# them so every column lines up row-for-row.
in_vocab = ic['ITEM_DESC'].map(lambda word: word in ic_model.wv)
ic = ic[in_vocab].reset_index(drop=True)

icx, icy = get_2d_coordinates(ic_model, ic['ITEM_DESC'])
ic_data = {'Category': ic['SUB_CATEGORY'],
            'Words': ic['ITEM_DESC'],
            'X': icx,
            'Y': icy}
ic_df = pd.DataFrame(ic_data)
ic_fig = px.scatter(ic_df, x=icx, y=icy, color=ic_df['Category'], hover_name=ic_df['Words'], title='IceCream Data')
ic_fig.show()

[Image: screenshot of the resulting "IceCream Data" scatter plot]


Solution

  • In plotly-python, I don't think there's an easy way to retrieve the location of the cursor. You can try using go.FigureWidget to highlight a trace as described in this answer, but I think you're going to be limited with plotly-python, and I'm not sure whether highlighting the n closest points will be possible.

    However, I believe you can accomplish what you want in plotly-dash, since it supports callbacks. This means you can retrieve the location of your cursor, calculate the n data points closest to it, and highlight those data points as needed.

    Below is an example of such a solution. If you haven't seen this pattern before it looks complicated, but here is what happens. The callback receives the point you clicked as its input. Because plotly is plotly.js under the hood, this data arrives as a plain dictionary (not a plotly-python object). I then find the three data points closest to the clicked point by comparing its coordinates with those of every other point in the dataframe. Next, I add those three points to the figure as new traces, coloured teal (or any colour of your choosing). Finally, I send the modified figure back as the output, which updates the graph.

    I am using click instead of hover because hover would cause the highlighted points to flicker too much as you drag your mouse through the points.

    Also, the Dash app doesn't work perfectly: I believe there is an issue when you double-click on points (in the GIF below you can see me click once before it starts working). Still, this basic framework is hopefully close enough to what you want. Cheers!

    import gensim
    import numpy as np
    import pandas as pd
    from sklearn.manifold import TSNE
    import plotly.express as px
    import plotly.graph_objects as go
    
    import json
    
    import dash
    from dash import dcc, html, Input, Output
    
    # Stylesheet URL(s) handed to Dash to be loaded into the app's page.
    external_stylesheets = [
        'https://codepen.io/chriddyp/pen/bWLwgP.css',
    ]
    app = dash.Dash(
        __name__,
        external_stylesheets=external_stylesheets,
    )
    
    
    def get_2d_coordinates(model, words, *, return_labels=False):
        """Project the word2vec vectors of ``words`` to 2-D with t-SNE.

        Words for which ``model.wv.get_vector`` raises ``KeyError`` (not in
        the vocabulary) are skipped. The returned arrays can therefore be
        shorter than ``words``. Pass ``return_labels=True`` to learn which
        words were kept.

        Args:
            model: A trained gensim ``Word2Vec`` model (anything exposing
                ``model.wv.get_vector``).
            words: Iterable of tokens to look up.
            return_labels: If true, also return the list of embedded words,
                in the same order as the coordinates.

        Returns:
            ``(x_coords, y_coords)``, or ``(x_coords, y_coords, labels)`` when
            ``return_labels`` is true. If no word has a vector, the
            coordinate arrays are empty.
        """
        vectors = []
        labels = []
        for word in words:
            try:
                vectors.append(model.wv.get_vector(word))
            except KeyError:
                # Out-of-vocabulary entry: skip it (best-effort, like the
                # original), but let any other error surface.
                continue
            labels.append(word)

        if not vectors:
            # t-SNE cannot be fitted on zero samples; return empty coordinates.
            empty = np.empty(0, dtype='f')
            if return_labels:
                return empty, empty.copy(), labels
            return empty, empty.copy()

        # Stack once instead of np.append per word (which re-copies the whole
        # array each time). The vector width comes from the model rather than
        # a hard-coded 100, so models of any vector size work.
        arr = np.asarray(vectors, dtype='f')
        tsne = TSNE(n_components=2, random_state=0)
        np.set_printoptions(suppress=True)  # kept from the original: affects NumPy printing only
        Y = tsne.fit_transform(arr)
        x_coords = Y[:, 0]
        y_coords = Y[:, 1]
        if return_labels:
            return x_coords, y_coords, labels
        return x_coords, y_coords
    
    ic_model = gensim.models.Word2Vec.load("w2v_IceCream.model")
    ic = pd.read_csv('ic_prods.csv')

    # get_2d_coordinates() skips words that have no vector. That would leave
    # the coordinate arrays shorter than the CSV columns and make
    # pd.DataFrame() below fail. Keep only rows whose ITEM_DESC is in the
    # vocabulary, and renumber them so every column lines up row-for-row.
    in_vocab = ic['ITEM_DESC'].map(lambda word: word in ic_model.wv)
    ic = ic[in_vocab].reset_index(drop=True)

    icx, icy = get_2d_coordinates(ic_model, ic['ITEM_DESC'])
    ic_data = {'Category': ic['SUB_CATEGORY'],
                'Words': ic['ITEM_DESC'],
                'X': icx,
                'Y': icy}

    ic_df = pd.DataFrame(ic_data)
    ic_fig = px.scatter(ic_df, x=icx, y=icy, color=ic_df['Category'], hover_name=ic_df['Words'], title='IceCream Data')

    # The click callback truncates figure['data'] back to this length before
    # adding highlight traces. Count the traces px.scatter actually created
    # rather than re-deriving the number from the category values, which can
    # disagree (e.g. with missing categories).
    NUMBER_OF_TRACES = len(ic_fig.data)
    ic_fig.update_layout(clickmode='event+select')
    
    # Single-page layout: one graph showing the t-SNE scatter; its id is what
    # the click callback reads from and writes to.
    app.layout = html.Div(
        children=[
            dcc.Graph(id='ic_figure', figure=ic_fig),
        ]
    )
    
    def get_n_closest_points(x0, y0, df=None, n=4):
        """Return the coordinates of the points nearest to ``(x0, y0)``.

        The ``n`` nearest rows are ranked by Euclidean distance. The single
        nearest row is assumed to be the clicked point itself and is dropped,
        so the default ``n=4`` yields the three nearest neighbours.

        Args:
            x0: X coordinate of the reference (clicked) point.
            y0: Y coordinate of the reference (clicked) point.
            df: DataFrame with numeric ``"X"`` and ``"Y"`` columns. Defaults
                to the module-level ``ic_df``. It is never modified.
            n: How many nearest rows to consider, including the point itself.

        Returns:
            NumPy array of ``[x, y]`` rows ordered from nearest to farthest.
        """
        if df is None:
            # Resolved at call time instead of being a DataFrame default
            # argument that was built once and then mutated on every call.
            df = ic_df
        xy = df[["X", "Y"]].to_numpy()
        # Squared distance ranks points exactly like the true distance, so
        # the square root is unnecessary.
        dist2 = (xy[:, 0] - x0) ** 2 + (xy[:, 1] - y0) ** 2
        # Stable sort keeps ties in DataFrame order; skip position 0, which
        # is the point itself.
        order = np.argsort(dist2, kind="stable")[1:n]
        return xy[order]
    
    @app.callback(
        Output('ic_figure', 'figure'),
        [Input('ic_figure', 'clickData')],
        # The figure is read as State rather than Input: this callback writes
        # the figure itself, so it only needs the current value and should
        # fire on clicks alone, not on its own output.
        [dash.State('ic_figure', 'figure')],
    )
    def display_hover_data(clickData, figure):
        """Highlight the points nearest to the clicked point.

        Args:
            clickData: Plotly click-event payload for the graph; ``None``
                until the user clicks.
            figure: The graph's current figure as a plain dict.

        Returns:
            The figure with its original traces, followed by one teal
            single-point trace for each neighbour that
            ``get_n_closest_points`` returns.
        """
        if not clickData or not clickData.get('points'):
            return figure

        clicked = clickData['points'][0]
        closest_points = get_n_closest_points(clicked['x'], clicked['y'])

        # Remove highlight traces added by a previous click so they do not
        # accumulate; the first NUMBER_OF_TRACES traces are the originals.
        del figure['data'][NUMBER_OF_TRACES:]

        figure['data'].extend(
            {
                'marker': {'color': 'teal', 'symbol': 'circle'},
                'mode': 'markers',
                'orientation': 'v',
                'showlegend': False,
                'x': [x],
                'xaxis': 'x',
                'y': [y],
                'yaxis': 'y',
                'type': 'scatter',
                # Mark the point as selected; presumably this keeps it at full
                # opacity while clickmode='event+select' dims unselected points.
                'selectedpoints': [0],
            }
            for x, y in closest_points
        )
        return figure
    
    if __name__ == '__main__':
        # Newer Dash releases replace the deprecated ``run_server`` with
        # ``run``; fall back to ``run_server`` on versions that lack ``run``.
        _run_app = getattr(app, 'run', None) or app.run_server
        _run_app(debug=True)
    

    [Image: animated GIF of the Dash app highlighting the three nearest points after a click]