Search code examples
python, plotly, plotly-dash

Dash Annotating Lineplot Dynamically Between Subplots


I have a dataset similar to the one below. Please note that a single ID can have multiple values.

import pandas as pd
import numpy as np
import random

# Hourly timestamps spanning 2022-11-01 00:00 through 2022-11-05 23:00.
date_times = pd.date_range('2022-11-01', '2022-11-05 23:00:00', freq='h')

# Take the row count from the timestamp index instead of repeating a literal,
# so every randomly generated column always has the same length as DATE_TIME.
n_samples = len(date_times)

# Synthetic readings: one row per timestamp, each with random SBP/DBP values,
# a random ID (so one ID may occur on several rows) and a random TIMEINTERVAL.
df = pd.DataFrame({
    'DATE_TIME': date_times,
    'SBP': [random.uniform(110, 160) for _ in range(n_samples)],
    'DBP': [random.uniform(60, 100) for _ in range(n_samples)],
    'ID': [random.randrange(1, 100) for _ in range(n_samples)],
    'TIMEINTERVAL': [random.randrange(1, 200) for _ in range(n_samples)],
})

# VISIT is the day of the month of the reading.
df['VISIT'] = df['DATE_TIME'].dt.day

# MODE labels the visit: day 1 -> 'New', days 2-3 -> 'InProgress', anything else -> 'Done'.
df['MODE'] = np.select(
    [df['VISIT'] == 1, df['VISIT'].isin([2, 3])],
    ['New', 'InProgress'],
    'Done',
)

I use the following Dash code to build the dropdown and the range slider:

app = Dash(__name__)

# Page structure, top to bottom: heading, the graph that the callback fills,
# a caption, the ID dropdown, the TIMEINTERVAL range slider and an output div.
app.layout = html.Div(
    children=[
        html.H4('Interactive Scatter Plot with ABPM dataset'),
        dcc.Graph(id="scatter-plot"),
        html.P("Filter by time interval:"),
        # Dropdown used to pick the ID whose readings are plotted.
        dcc.Dropdown(df.ID.unique(), id='pandas-dropdown-1'),
        dcc.RangeSlider(
            id='range-slider',
            min=0,
            max=600,
            step=10,
            # One labelled tick every 50 units: 0, 50, ..., 600.
            marks={tick: str(tick) for tick in range(0, 601, 50)},
            value=[0, 600],
        ),
        html.Div(id='dd-output-container'),
    ]
)


@app.callback(
    Output("scatter-plot", "figure"),
    Input("pandas-dropdown-1", "value"),
    Input("range-slider", "value"),
    prevent_initial_call=True,
)
def update_bar_chart(value, slider_range):
    """Rebuild the faceted scatter figure for the chosen ID and interval.

    ``value`` is the current dropdown selection and ``slider_range`` is the
    two-element range-slider value. Rows of ``df`` are kept when their ID
    equals ``value`` and their TIMEINTERVAL lies strictly between the two
    slider bounds.

    Returns the new figure, or ``dash.no_update`` when no row matches so the
    graph is left as it is.
    """
    # ``low`` and ``high`` are referenced by name from inside the query string.
    low, high = slider_range
    selection = df.query("ID == @value & TIMEINTERVAL > @low & TIMEINTERVAL < @high").copy()

    # Nothing matched: leave the current figure untouched.
    if selection.shape[0] == 0:
        return dash.no_update

    # One facet per VISIT, wrapped into two columns; marker symbol encodes MODE.
    fig = px.scatter(
        selection,
        x="DATE_TIME",
        y=["SBP", "DBP"],
        hover_data=['TIMEINTERVAL'],
        facet_col='VISIT',
        facet_col_wrap=2,
        symbol='MODE',
    )

    # Give every facet its own x-axis range and show its tick labels.
    fig.update_xaxes(matches=None, showticklabels=True)
    return fig


# Start the Dash server with debug mode enabled and the auto-reloader disabled.
# NOTE(review): newer Dash releases favour `app.run(...)` over `run_server` — confirm against the installed version.
app.run_server(debug=True, use_reloader=False)

If the VISIT column of df1 contains more than one distinct value, I would like to annotate the subplots with arrows that show the reading order. To do so, I wrote the following script inside the update_bar_chart function, but it does not run (Python rejects it with an indentation error).

# Attempted variant of the callback body that adds an arrow annotation.
# It is reproduced exactly as posted and does not parse; the review notes
# below mark the lines that cause the failure.
def update_bar_chart(value,slider_range):
    low, high = slider_range
    df1 = df.query("ID == @value & TIMEINTERVAL > @low & TIMEINTERVAL < @high").copy() 
    
    if df1.shape[0] != 0:
        fig = px.scatter(df1, x="DATE_TIME", y=["SBP","DBP"],
                         hover_data=['TIMEINTERVAL'],facet_col='VISIT',
                         facet_col_wrap=2,
                         symbol='MODE')
        
        fig.update_xaxes(matches= None, showticklabels=True)

                # NOTE(review): this `if` is indented deeper than the statement before it
                # without opening a new block, which Python rejects (unexpected indent).
                # NOTE(review): `df1.VISIT!=1` compares the whole column element-wise, so the
                # result is a Series rather than a single True/False usable as a condition.
                if df1.VISIT!=1:
                # NOTE(review): the body of the `if` is not indented relative to the `if` line.
                fig.add_annotation(
                xref="x domain",
                yref="y domain",
                # The arrow head will be 25% along the x axis, starting from the left
                x=0.25,
                # The arrow head will be 40% along the y axis, starting from the bottom
                y=0.4,
                arrowhead=2,
            )

                   # NOTE(review): this indentation level matches no enclosing block.
                   return fig
    else: 
        return dash.no_update


# Start the Dash server with debug mode enabled and the auto-reloader disabled.
# NOTE(review): newer Dash releases favour `app.run(...)` over `run_server` — confirm against the installed version.
app.run_server(debug=True, use_reloader=False)

What I have is: enter image description here

What I want to achieve is:

enter image description here

How can I add those arrows to make the plots easier to read? The number of arrows should change dynamically, because each ID has a different number of visits.


Solution

  • Add the annotations by looping over the rows of subplots. Set an annotation's 'xref'/'yref' property to an 'x/y* domain' value so that its coordinates are given as a fraction of that subplot's x/y domain (the width/height of the subplot's frame). Also use the 'ax'/'ay' properties (together with 'axref'/'ayref') to specify the starting point of each arrow.

    This is an example.

    n_plots = len(df1['VISIT'].unique())
    n_rows = (n_plots+1)//2
    row_spacing = 1/((1/0.5+1) * n_rows - 1) # 50% of y domain
    col_spacing = 0.1
    col_spacing_in_x_domain = 1/((1/col_spacing-1)/2)
    row_spacing_in_y_domain = 1/((1/row_spacing+1)/n_rows - 1)
    
    fig = px.scatter(df1,
        facet_col='VISIT',
        facet_col_wrap=2,
        facet_row_spacing=row_spacing, facet_col_spacing=col_spacing,
        ...
    )
    fig.update_xaxes(matches= None, showticklabels=True)
    
    for i in range(n_rows):
        # A row number 1 is the bottom one.
        trace = next(fig.select_traces(row=n_rows-i, col=1))
        xref, yref = trace.xaxis + ' domain', trace.yaxis + ' domain'
        if i*2+1 < n_plots:
            fig.add_annotation(
                xref=xref, yref=yref, axref=xref, ayref=yref,
                ax=1, ay=0.5,
                x=1 + col_spacing_in_x_domain, y=0.5,
                arrowhead = 2,
            )
        if i*2+2 < n_plots:
            fig.add_annotation(
                xref=xref, yref=yref, axref=xref, ayref=yref,
                ax=1 + col_spacing_in_x_domain, ay=0.5,
                x=1, y=-row_spacing_in_y_domain - 0.5,
                arrowhead = 2,
            )