Search code examples
Tags: python, plotly, hover

Plotly go: how to add an image to the hover feature?


The first segment of code below (code #1) generates a graph in which 1) hovering over a point displays the data associated with that point, and 2) clicking a point saves its data to a list. For this code, I would also like to display the image associated with each point. Assume the dataframe df has a column 'image' which contains the image pixel/array data of each point. I found code online (code #2) that implements this image hover feature, but without the click feature. I'm having a hard time combining the image hover feature with the click feature. So, basically, I'm trying to merge the image hover feature of code #2 into code #1, which already has the click feature (click on a point and its data is saved to a list).

CODE # 1 (with click feature):

import json
from textwrap import dedent as d
import pandas as pd
import plotly.graph_objects as go
import numpy as np
import dash
from dash import dcc
import dash_html_components as html
import plotly.express as px
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash
import warnings

# Keep the notebook output readable: hide pandas/plotly FutureWarnings.
warnings.simplefilter(category=FutureWarning, action='ignore')

# Dash application hosted from inside Jupyter.
app = JupyterDash(__name__)

# Inline CSS for the <pre> panels that echo the hover/click payloads.
styles = dict(pre=dict(border='thin lightgrey solid', overflowX='scroll'))

# Gapminder rows for Oceania; px.line draws one line trace per country.
df = px.data.gapminder().query("continent=='Oceania'")

fig = px.line(
    df,
    x="year",
    y="lifeExp",
    color="country",
    title="No label selected",
)
fig.update_traces(mode="markers+lines")

# Page layout: the graph on top, then a row with two <pre> panels.
# 'hoverdata2' receives the list of clicked points and 'hoverdata1' the raw
# JSON of the last click (both are written by the callback below).
app.layout = html.Div(
    children=[
        dcc.Graph(id='figure1', figure=fig),
        html.Div(
            className='row',
            children=[
                html.Div(
                    className='three columns',
                    children=[
                        dcc.Markdown(d("""Hoverdata using figure references""")),
                        html.Pre(id='hoverdata2', style=styles['pre']),
                    ],
                ),
                html.Div(
                    className='three columns',
                    children=[
                        dcc.Markdown(d("""
              
              Full hoverdata
            """)),
                        html.Pre(id='hoverdata1', style=styles['pre']),
                    ],
                ),
            ],
        ),
    ],
)

# Module-level accumulator: the callback appends one [name, x, y] per click.
store = []

@app.callback(
    Output('figure1', 'figure'),
    Output('hoverdata1', 'children'),
    Output('hoverdata2', 'children'),
    [Input('figure1', 'clickData')])
def display_hover_data(hoverData):
    """Record a clicked point and refresh the figure title and text panels.

    Despite its name, ``hoverData`` is wired to the graph's ``clickData``.

    Returns a 3-tuple: the (module-level) figure, the click payload as
    pretty-printed JSON, and ``str`` of every point recorded so far.
    """
    # Initial call, before anything has been clicked.
    if hoverData is None:
        return fig, 'None selected', 'None selected'

    clicked = hoverData['points'][0]
    trace = fig.data[clicked['curveNumber']]
    idx = clicked['pointNumber']
    label = trace['name']

    # Side effects on module-level state: grow `store`, retitle `fig`.
    store.append([label, trace['x'][idx], trace['y'][idx]])
    fig.update_layout(title='Last label was ' + label)

    return fig, json.dumps(hoverData, indent=2), str(store)

# Launch the JupyterDash server in 'external' mode on port 7077, with the
# dev-tools UI, hot reload and threaded request handling enabled.
app.run_server(mode='external', port = 7077, dev_tools_ui=True,
          dev_tools_hot_reload =True, threaded=True)

CODE # 2 (includes image hover feature):

from jupyter_dash import JupyterDash
from dash import Dash, dcc, html, Input, Output, no_update
import plotly.graph_objects as go
import pandas as pd

# Three fixed sample points, each with a marker colour and the URL of the
# image to show in the tooltip when the point is hovered.
df = pd.DataFrame(
    dict(
        x=[1, 2, 3],
        y=[2, 3, 4],
        z=[3, 4, 5],
        color=['red', 'green', 'blue'],
        img_url=[
            "https://upload.wikimedia.org/wikipedia/commons/thumb/0/02/Stack_Overflow_logo.svg/2880px-Stack_Overflow_logo.svg.png",
            "https://upload.wikimedia.org/wikipedia/commons/3/37/Plotly-logo-01-square.png",
            "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ed/Pandas_logo.svg/2880px-Pandas_logo.svg.png",
        ],
    )
)

scatter = go.Scatter3d(
    x=df['x'],
    y=df['y'],
    z=df['z'],
    mode='markers',
    marker=dict(color=df['color']),
)
fig = go.Figure(data=[scatter])

# Suppress plotly.js's own hover label so the Dash tooltip replaces it.
# hoverinfo must be "none", not "skip": "skip" would also stop the hover
# events that the Dash callback relies on.
fig.update_traces(hoverinfo="none", hovertemplate=None)

# Same fixed [-1, 8] range on all three scene axes.
fig.update_layout(
    scene=dict(
        xaxis=dict(range=[-1, 8]),
        yaxis=dict(range=[-1, 8]),
        zaxis=dict(range=[-1, 8]),
    ),
)

# Dash application hosted from Jupyter; the underlying server object is
# exposed at module level.
app = JupyterDash(__name__)
server = app.server

# The graph clears its hoverData when the pointer leaves a point, which lets
# the callback hide the tooltip (it checks for hoverData being None).
app.layout = html.Div(
    children=[
        dcc.Graph(id="graph-basic-2", figure=fig, clear_on_unhover=True),
        dcc.Tooltip(id="graph-tooltip"),
    ]
)


@app.callback(
    Output("graph-tooltip", "show"),
    Output("graph-tooltip", "bbox"),
    Output("graph-tooltip", "children"),
    Input("graph-basic-2", "hoverData"),
)
def display_hover(hoverData):
    """Show a tooltip containing the image of the hovered point.

    Returns ``(show, bbox, children)`` for the ``dcc.Tooltip``. With no
    hover the tooltip is hidden and its position/content are left as-is.
    """
    if hoverData is not None:
        # Only the first hovered point is used, although others may be present.
        point = hoverData["points"][0]
        anchor = point["bbox"]
        image_url = df.iloc[point["pointNumber"]]['img_url']

        tooltip_body = html.Div(
            [html.Img(src=image_url, style={"width": "100%"})],
            style={'width': '100px', 'white-space': 'normal'},
        )
        return True, anchor, [tooltip_body]

    return False, no_update, no_update

# Launch the JupyterDash server in 'inline' mode (embedded in the notebook output).
app.run_server(mode="inline")

Solution

    • You want a single callback that handles both hover and click:
      • on hover, display the image associated with the point and the full hover info;
      • on click, update the list of clicked points and the figure title.
    • The question assumes the dataframe df has a column 'image'; I have created one that contains a base64-encoded image (each country's flag).
    • I have inserted this into the figure as customdata (via the hover_data parameter of px.line).
    • I have added an additional div, "image", to hold the hovered image.
    • I have changed the callback so that it behaves as it did before and also fills the new div. It uses the base64-encoded image, prefixed with the required "data:image/png;base64,".
    • See also https://dash.plotly.com/vtk/click-hover and https://dash.plotly.com/advanced-callbacks
    import json
    from textwrap import dedent as d
    import pandas as pd
    import plotly.graph_objects as go
    import numpy as np
    import dash
    import plotly.express as px
    from dash.dependencies import Input, Output
    from jupyter_dash import JupyterDash
    import warnings
    import base64, io, requests
    from PIL import Image
    from pathlib import Path
    
    # Keep notebook output free of pandas/plotly FutureWarnings.
    warnings.simplefilter(category=FutureWarning, action="ignore")

    # Dash application hosted from Jupyter, plus CSS for the <pre> panels.
    app = JupyterDash(__name__)
    styles = dict(pre=dict(border="thin lightgrey solid", overflowX="scroll"))
    
    # Flag metadata for the Oceania countries: ISO codes plus the URL of a
    # small GIF flag. Built from literals rather than an embedded CSV string:
    # an indented triple-quoted CSV gives every data row leading spaces
    # ("    Australia"), and the later merge on "country" then silently
    # matches nothing.
    df_flag = pd.DataFrame(
        {
            "country": ["Australia", "New Zealand"],
            "Alpha-2 code": ["AU", "NZ"],
            "Alpha-3 code": ["AUS", "NZL"],
            "URL": [
                "https://www.worldometers.info//img/flags/small/tn_as-flag.gif",
                "https://www.worldometers.info//img/flags/small/tn_nz-flag.gif",
            ],
        }
    )

    # Local cache directory for the downloaded flag images.
    f = Path.cwd().joinpath("flags")
    f.mkdir(exist_ok=True)

    # Download each flag once, saved as flags/<Alpha-3 code>.gif.
    for code, url in zip(df_flag["Alpha-3 code"], df_flag["URL"]):
        flag_file = f.joinpath(f"{code}.gif")
        if flag_file.exists():
            continue
        # Write to a temporary ".part" file and rename only on success, so an
        # interrupted download is not mistaken for a cached image next run.
        part_file = flag_file.with_name(flag_file.name + ".part")
        with requests.get(
            url, stream=True, headers={"User-Agent": "XY"}, timeout=30
        ) as resp:
            # Fail loudly instead of caching an HTML error page as a .gif.
            resp.raise_for_status()
            with open(part_file, "wb") as fd:
                for chunk in resp.iter_content(chunk_size=128):
                    fd.write(chunk)
        part_file.replace(flag_file)
    
    # encode
    def b64image(country):
        """Return ``flags/<country>.gif`` re-encoded as a base64 PNG string.

        ``country`` is the file stem of the cached flag image; the caller in
        this script passes the Alpha-3 code (e.g. "AUS"). The result is meant
        to follow a "data:image/png;base64," prefix in an <img> src.
        """
        flag_path = Path.cwd() / "flags" / f"{country}.gif"
        # Use context managers so the image's file handle and the buffer are
        # released deterministically (Image.open keeps the file open until
        # the image is closed).
        with Image.open(flag_path) as im, io.BytesIO() as buf:
            im.save(buf, format="PNG")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
    
    
    # Attach a base64 PNG of each flag (the file stem is the Alpha-3 code).
    df_flag["image"] = df_flag["Alpha-3 code"].apply(b64image)

    # Oceania life-expectancy rows joined with each country's flag URL and image.
    df = px.data.gapminder().query("continent=='Oceania'").merge(df_flag, on="country")

    # hover_data is used to carry URL (shown) and image (hidden) in each
    # trace's customdata; the callback reads the image at customdata index 1.
    # NOTE(review): confirm your plotly version keeps hover_data columns set
    # to False in customdata; if not, pass custom_data=["URL", "image"].
    fig = px.line(
        df,
        x="year",
        y="lifeExp",
        color="country",
        title="No label selected",
        hover_data={"URL": True, "image": False},
    )
    fig.update_traces(mode="markers+lines")
    
    # Layout: the graph, then a row with the hovered image, the list of
    # clicked points ('hoverdata2') and the raw hover JSON ('hoverdata1').
    app.layout = dash.html.Div(
        children=[
            dash.dcc.Graph(id="figure1", figure=fig),
            dash.html.Div(
                className="row",
                children=[
                    # Filled by the callback with an <img> of the hovered flag.
                    dash.html.Div(id="image"),
                    dash.html.Div(
                        className="three columns",
                        children=[
                            dash.dcc.Markdown(d("""Hoverdata using figure references""")),
                            dash.html.Pre(id="hoverdata2", style=styles["pre"]),
                        ],
                    ),
                    dash.html.Div(
                        className="three columns",
                        children=[
                            dash.dcc.Markdown(
                                d(
                                    """
                  
                  Full hoverdata
                """
                                )
                            ),
                            dash.html.Pre(id="hoverdata1", style=styles["pre"]),
                        ],
                    ),
                ],
            ),
        ]
    )
    
    # container for clicked points in callbacks (module-level, grows per click)
    store = []


    @app.callback(
        Output("figure1", "figure"),
        Output("hoverdata1", "children"),
        Output("hoverdata2", "children"),
        Output("image", "children"),
        [Input("figure1", "clickData"), Input("figure1", "hoverData")],
    )
    def display_hover_data(clickData, hoverData):
        """Handle both click and hover events on ``figure1``.

        * click: append ``[trace name, x, y]`` of the clicked point to
          ``store``, retitle the figure, and show ``store`` in 'hoverdata2'.
        * hover: show the full hover payload in 'hoverdata1' and the point's
          flag image in the 'image' div.
        * otherwise (initial call): placeholder text.

        Returns the four outputs ``(figure, hoverdata1, hoverdata2, image)``;
        ``dash.no_update`` leaves the corresponding output untouched.
        """
        # Which input fired this call?
        trigger = dash.callback_context.triggered[0]["prop_id"]

        if trigger == "figure1.clickData":
            point = clickData["points"][0]
            trace = fig.data[point["curveNumber"]]
            idx = point["pointNumber"]
            store.append([trace["name"], trace["x"][idx], trace["y"][idx]])
            fig.update_layout(title="Last label was " + trace["name"])
            return fig, dash.no_update, str(store), dash.no_update

        if trigger == "figure1.hoverData":
            # customdata is expected to be [URL, b64 image] (see hover_data in
            # px.line). A plain-URL alternative would be:
            #   dash.html.Img(src=hoverData["points"][0]["customdata"][0], ...)
            b64 = hoverData["points"][0]["customdata"][1]
            dimg = dash.html.Img(
                src="data:image/png;base64," + b64,
                style={"width": "30%"},
            )
            # Do not resend the figure on hover: returning it would push the
            # whole figure to the browser on every hover event and can reset
            # the user's zoom/pan. Only the text panel and image change.
            return dash.no_update, json.dumps(hoverData, indent=2), dash.no_update, dimg

        return fig, "None selected", "None selected", "no image"
    
    
    # Alternative launch configuration from the question's code ('external'
    # mode on port 7077 with dev tools/hot reload), kept for reference:
    # app.run_server(mode='external', port = 7077, dev_tools_ui=True,
    #           dev_tools_hot_reload =True, threaded=True)
    # Launch the JupyterDash server in 'inline' mode (embedded in the notebook output).
    app.run_server(mode="inline")