Search code examples
Tags: python, plotly, plotly-dash

Include figure parameters as a callback option - Dash


I have a callback function that switches a figure between several spatial map types. When Hexbin is selected, I want to show extra parameters (opacity, nx_hexagon, min count) that alter the figure.

Is it possible to insert these parameter controls only when the Hexbin option is selected?

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
import dash_bootstrap_components as dbc
import plotly.express as px
import plotly.graph_objs as go
import pandas as pd
import numpy as np
import plotly.figure_factory as ff


# Toy dataset: 14 base points (category + coordinates) tiled N times.
N = 30
data = pd.concat(
    [pd.DataFrame({
        'Cat': ['t', 'y', 'y', 'y', 'f', 'f', 'j', 'k', 'k', 'k', 's', 's', 's', 's'],
        'LAT': [5, 6, 4, 5, 4, 7, 8, 9, 5, 6, 18, 17, 15, 16],
        'LON': [10, 11, 9, 11, 10, 8, 8, 5, 8, 7, 18, 16, 16, 17],
    })] * N,
    ignore_index=True,
)

# Give each category one colour from the Plotly qualitative palette, in order
# of first appearance (zip stops at the shorter of the two sequences).
data['Color'] = data['Cat'].map(
    dict(zip(data['Cat'].unique(), px.colors.qualitative.Plotly))
)

# Category -> colour lookup used by the scatter map's color_discrete_map.
Type_Category = data['Cat'].unique()
Color = data['Color'].unique()
Type_cats = dict(zip(Type_Category, Color))


external_stylesheets = [dbc.themes.SPACELAB, dbc.icons.BOOTSTRAP]

app = dash.Dash(__name__, external_stylesheets = external_stylesheets)

# Sidebar controls: category filter, map-type selector and hexbin parameter
# sliders. Dash requires every component id in a layout to be unique (a repeated
# id is rejected), so each slider gets its own id.
# NOTE: the sliders are not yet wired to the chart callback, whose Inputs are
# only 'Cats' and 'maps'.
filtering = html.Div(children=[
    html.Div(children=[
        html.Label('Cats', style = {'paddingTop': '2rem', 'display': 'inline-block'}),
        dcc.Checklist(
            id = 'Cats',
            options = [
                {'label': 't', 'value': 't'},
                {'label': 'y', 'value': 'y'},
                {'label': 'f', 'value': 'f'},
                {'label': 'j', 'value': 'j'},
                {'label': 'k', 'value': 'k'},
                {'label': 's', 'value': 's'},
            ],
            value = ['t', 'y', 'f', 'j', 'k', 's'],
        ),

        html.Label('Type', style = {'paddingTop': '2rem', 'display': 'inline-block'}),
        dcc.RadioItems(['Scatter', 'Heatmap', 'Hexbin'], 'Scatter',
                       inline = True, id = 'maps'),

        html.Label('Opacity', style = {'paddingTop': '2rem', 'display': 'inline-block'}),
        dcc.Slider(0, 1, 0.1,
               value = 0.5,
               id = 'my-slider'),

        # nx_hexagon is a hexagon count; a grid of 0 hexagons is not usable,
        # so the slider starts at its first step (5).
        html.Label('nx_hexagon', style = {'paddingTop': '2rem', 'display': 'inline-block'}),
        dcc.Slider(5, 40, 5,
               value = 20,
               id = 'my-slider2'),

        html.Label('min count', style = {'paddingTop': '2rem', 'display': 'inline-block'}),
        dcc.Slider(0, 5, 1,
               value = 0,
               id = 'my-slider3'
    ),
    ],
    )
])

# Page layout: a narrow sidebar holding the filter controls, a graph with no id
# (no callback targets it), and the 'chart' graph that the callback fills.
app.layout = dbc.Container(
    children=[
        dbc.Row(children=[
            dbc.Col(children=[html.Div(filtering)], width = 2),
            dbc.Col(children=[html.Div(dcc.Graph())]),
            dbc.Col(children=[html.Div(dcc.Graph(id = 'chart'))]),
        ]),
    ],
    fluid = True,
)

# Frame that the chart callback filters by category.
df = data

@app.callback(
    Output('chart', 'figure'),
    [Input("Cats", "value"),
    Input("maps", "value")])
def scatter_chart(cats, maps):
    """Build the map figure for the selected categories and map type.

    Parameters
    ----------
    cats : list
        Values ticked in the ``Cats`` checklist; only rows of ``df`` whose
        ``Cat`` is in this list are plotted.
    maps : str
        Map type from the ``maps`` radio items. 'Scatter' and 'Heatmap' get
        their own figures; any other value renders the hexbin map.

    Returns
    -------
    plotly Figure for the ``chart`` graph.
    """
    dff = df[df['Cat'].isin(cats)]

    if maps == 'Scatter':
        # px already returns a Figure, so it is used directly instead of being
        # re-wrapped in go.Figure (which also avoided shadowing global `data`).
        fig = px.scatter_mapbox(data_frame = dff,
                                lat = 'LAT',
                                lon = 'LON',
                                color = 'Cat',
                                color_discrete_map = Type_cats,
                                zoom = 3,
                                mapbox_style = 'carto-positron',
                                )

    elif maps == 'Heatmap':
        # One weight per point, aligned with lat/lon. A meshgrid of the columns
        # would give an (n, n) array that cannot be paired point-by-point, and
        # lat/lon must not be swapped or the heatmap is mirrored relative to
        # the scatter and hexbin maps.
        weights = np.cos(dff['LAT'] / 2) + np.sin(dff['LON'] / 4)

        fig = go.Figure(data =
           go.Densitymapbox(lat = dff['LAT'],
                            lon = dff['LON'],
                            z = weights,
                            )
                    )

        fig.update_layout(mapbox_style = "carto-positron")

    elif dff.empty:
        # No categories selected: nothing to bin, so show an empty base map
        # instead of asking create_hexbin_mapbox to bin zero points.
        fig = go.Figure(go.Scattermapbox(lat = [], lon = []))
        fig.update_layout(mapbox_style = "carto-positron")

    else:
        # Same base map as the other map types.
        fig = ff.create_hexbin_mapbox(data_frame = dff,
                                      lat = "LAT",
                                      lon = "LON",
                                      mapbox_style = "carto-positron",
        )

    return fig

if __name__ == '__main__':
    # app.run replaces the deprecated app.run_server (removed in Dash 3).
    app.run(debug=True, port = 8051)

Intended output when hexbin is selected:

(Screenshot of the intended Hexbin output; image not included.)


Solution

  • You can use the slider values as Inputs and, when Hexbin is selected, pass them to the figure settings. Something like the following:

    import dash
    from dash import dcc
    from dash import html
    from dash.dependencies import Input, Output
    import dash_bootstrap_components as dbc
    import plotly.express as px
    import plotly.graph_objs as go
    import pandas as pd
    import numpy as np
    import plotly.figure_factory as ff
    
    # Toy dataset: 14 base points (category + coordinates) tiled N times.
    N = 30
    data = pd.concat(
        [pd.DataFrame({
            'Cat': ['t', 'y', 'y', 'y', 'f', 'f', 'j', 'k', 'k', 'k', 's', 's', 's', 's'],
            'LAT': [5, 6, 4, 5, 4, 7, 8, 9, 5, 6, 18, 17, 15, 16],
            'LON': [10, 11, 9, 11, 10, 8, 8, 5, 8, 7, 18, 16, 16, 17],
        })] * N,
        ignore_index=True,
    )

    # One palette colour per category, in order of first appearance
    # (zip stops at the shorter of the two sequences).
    data['Color'] = data['Cat'].map(
        dict(zip(data['Cat'].unique(), px.colors.qualitative.Plotly))
    )

    # Category -> colour lookup used by the scatter map.
    Type_Category = data['Cat'].unique()
    Color = data['Color'].unique()
    Type_cats = dict(zip(Type_Category, Color))
    
    external_stylesheets = [dbc.themes.SPACELAB, dbc.icons.BOOTSTRAP]

    # The slider ids used as callback Inputs are not in the initial layout
    # (the 'filter' Div starts empty), hence suppress_callback_exceptions.
    app = dash.Dash(__name__,
                    external_stylesheets=external_stylesheets,
                    suppress_callback_exceptions=True)

    app.layout = dbc.Container(
        [
            dbc.Row(
                [
                    # Sidebar: category filter, map type, and a slot for the
                    # hexbin sliders that a callback fills in.
                    dbc.Col(
                        [
                            html.Label('Cats', style=dict(paddingTop='2rem', display='inline-block')),
                            dcc.Checklist(
                                id='Cats',
                                options=[{'label': cat, 'value': cat} for cat in ['t', 'y', 'f', 'j', 'k', 's']],
                                value=['t', 'y', 'f', 'j', 'k', 's'],
                            ),
                            html.Label('Type', style=dict(paddingTop='2rem', display='inline-block')),
                            dcc.RadioItems(['Scatter', 'Heatmap', 'Hexbin'], 'Scatter',
                                           inline=True, id='maps'),
                            html.Div(id='filter'),
                        ],
                        width=2,
                    ),
                    dbc.Col([html.Div(dcc.Graph())], width=5),
                    dbc.Col([html.Div(dcc.Graph(id='chart'))], width=5),
                ]
            )
        ],
        fluid=True,
    )

    # Frame the chart callback filters by category.
    df = data
    
    
    @app.callback(
        Output('filter', 'children'),
        [Input("maps", "value")])
    def update_slider(maps):
        """Return the hexbin parameter controls for the ``filter`` container.

        Parameters
        ----------
        maps : str
            Map type selected in the ``maps`` radio items.

        Returns
        -------
        html.Div
            Three labelled sliders (ids 'my-slider', 'my-slider2', 'my-slider3')
            when ``maps == 'Hexbin'``; otherwise three empty Divs with the same
            ids, so the chart callback's Inputs still refer to components that
            exist in the layout.
        """
        if maps != 'Hexbin':
            return html.Div([html.Div(id='my-slider'),
                             html.Div(id='my-slider2'),
                             html.Div(id='my-slider3')])

        label_style = {'paddingTop': '2rem', 'display': 'inline-block'}
        return html.Div([html.Label('Opacity', style=label_style),
                         dcc.Slider(0, 1, 0.1,
                                    value=0.5,
                                    id='my-slider'),

                         # nx_hexagon is a hexagon count; 0 is not a usable
                         # grid size, so the slider starts at its first step.
                         html.Label('nx_hexagon', style=label_style),
                         dcc.Slider(5, 40, 5,
                                    value=20,
                                    id='my-slider2'),

                         html.Label('min count', style=label_style),
                         dcc.Slider(0, 5, 1,
                                    value=0,
                                    id='my-slider3')])
    
    
    @app.callback(
        Output('chart', 'figure'),
        [Input("Cats", "value"),
         Input("maps", "value"),
         Input("my-slider", "value"),
         Input("my-slider2", "value"),
         Input("my-slider3", "value")], prevent_initial_call=True)
    def scatter_chart(cats, maps, my_slider, my_slider2, my_slider3):
        """Build the map figure for the selected categories and map type.

        Parameters
        ----------
        cats : list
            Values ticked in the ``Cats`` checklist; only rows of ``df`` whose
            ``Cat`` is in this list are plotted.
        maps : str
            Map type from the ``maps`` radio items: 'Scatter', 'Heatmap' or 'Hexbin'.
        my_slider, my_slider2, my_slider3
            Hexbin ``opacity``, ``nx_hexagon`` and ``min_count``. Presumably
            None while the ``filter`` container holds placeholder Divs (a Div
            has no ``value``) -- treated as "sliders not ready".

        Returns
        -------
        A plotly Figure, or ``dash.no_update`` to keep the current figure.
        """
        dff = df[df['Cat'].isin(cats)]

        if maps == 'Scatter':
            # px already returns a Figure; no need to re-wrap it in go.Figure.
            return px.scatter_mapbox(data_frame=dff,
                                     lat='LAT',
                                     lon='LON',
                                     color='Cat',
                                     color_discrete_map=Type_cats,
                                     zoom=3,
                                     mapbox_style='carto-positron',
                                     )

        if maps == 'Heatmap':
            # One weight per point, aligned with lat/lon. A meshgrid of the
            # columns gives an (n, n) array that cannot be paired point-by-point,
            # and swapping lat/lon mirrors the heatmap relative to the other maps.
            weights = np.cos(dff['LAT'] / 2) + np.sin(dff['LON'] / 4)
            fig = go.Figure(data=go.Densitymapbox(lat=dff['LAT'],
                                                  lon=dff['LON'],
                                                  z=weights))
            fig.update_layout(mapbox_style="carto-positron")
            return fig

        if maps != 'Hexbin':
            # Unknown map type: keep the current figure rather than return None.
            return dash.no_update

        if any(value is None for value in (my_slider, my_slider2, my_slider3)):
            # The sliders are not rendered yet (e.g. switching to Hexbin fires
            # this callback before update_slider swaps them in). Keep the
            # current figure instead of swallowing the resulting TypeError and
            # returning None.
            return dash.no_update

        if dff.empty:
            # No categories selected: nothing to bin, so show an empty base map.
            fig = go.Figure(go.Scattermapbox(lat=[], lon=[]))
        else:
            fig = ff.create_hexbin_mapbox(data_frame=dff,
                                          lat="LAT",
                                          lon="LON",
                                          opacity=my_slider,
                                          nx_hexagon=my_slider2,
                                          min_count=my_slider3)

        fig.update_layout(mapbox_style="open-street-map")
        return fig
    
    
    if __name__ == '__main__':
        # app.run replaces the deprecated app.run_server (removed in Dash 3).
        app.run(debug=True, port=8051)
    

    (Screenshot of the resulting app; image not included.)