Search code examples
Tags: python, pickle, parquet, plotly-dash, feather

Faster serializations (pickle, parquet, feather, ...) than json in plotly dash Store?


Context

In a dashboard built with Plotly Dash, I need to perform an expensive download from a database only when one component is updated (a DatePickerRange that selects the period to download). The resulting DataFrame should then be reused by the other components (e.g. Dropdowns that filter the DataFrame) without repeating the expensive download.

The docs suggest using dash_core_components.Store as the Output of a callback that returns the DataFrame serialized to JSON, and then using the Store as an Input of the other callbacks, which deserialize the JSON back into a DataFrame.

Serialization to and from JSON is slow: each time I update a component, it takes 30 seconds to update the plot just for that step.

I tried faster serialization formats such as pickle, parquet and feather, but during deserialization I get an error stating that the object is empty (no such error appears when using JSON).

Question

Is it possible to serialize data in a Dash Store with methods faster than JSON, such as pickle, feather or parquet (they take roughly half the time for my dataset)? If so, how?

Code

import io
import traceback
import pandas as pd
from datetime import datetime, date, timedelta

import dash
import dash_core_components as dcc
import dash_html_components as html
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output
from plotly.subplots import make_subplots



# Dash application styled with the Bootstrap theme from dash-bootstrap-components.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Evaluated once, when the module is imported.
today = date.today()

# Period selector: defaults to the last 20 days and cannot go past today.
_date_picker = dcc.DatePickerRange(
    id='date_ranges',
    start_date=today - timedelta(days=20),
    end_date=today,
    max_date_allowed=today,
    display_format='MMM Do, YY',
)

# Multi-select of the columns to plot.
# NOTE(review): `options` and `default_options` are not defined in this
# snippet -- presumably they are built elsewhere; confirm.
_column_dropdown = dcc.Dropdown(
    id='dd_ycolnames',
    options=options,
    value=default_options,
    multi=True,
)

# The two input controls, stacked in a Bootstrap container.
_controls = dbc.Container([
    dbc.Row([dbc.Col(_date_picker, width=4)]),
    dbc.Row(dbc.Col(_column_dropdown)),
])

# Full-width row holding the graph; it starts with an empty figure.
_graph_row = dbc.Row([
    dbc.Col(dcc.Graph(id='graph_subplots', figure={}), width=12),
])

app.layout = html.Div([
    dbc.Row(dbc.Col(html.H1('PMC'))),
    dbc.Row(dbc.Col(html.H5('analysis'))),
    html.Hr(),
    html.Br(),
    _controls,
    _graph_row,
    # Store component whose `data` property carries the dataset.
    dcc.Store(id='store'),
])


@app.callback(
    Output('store', 'data'),
    [Input('date_ranges', 'start_date'), Input('date_ranges', 'end_date')],
)
def load_dataset(date_ranges_start, date_ranges_end):
    """Reload the dataset for the selected period and write it to the store.

    Args:
        date_ranges_start: Start of the period as a ``'%Y-%m-%d'`` string.
        date_ranges_end: End of the period as a ``'%Y-%m-%d'`` string.

    Returns:
        The frame returned by ``expensive_load_from_db``, serialized with
        ``DataFrame.to_parquet()``.

    Raises:
        ValueError: If either date string does not match ``'%Y-%m-%d'``.
    """
    logger.info('loading dataset...')
    period_start, period_end = (
        datetime.strptime(bound, '%Y-%m-%d')
        for bound in (date_ranges_start, date_ranges_end)
    )
    # The expensive step that should only run when the period changes.
    frame = expensive_load_from_db(period_start, period_end)
    logger.info('dataset to json...')
    # This is the serialization under test. The slower JSON variant that does
    # work is: frame.to_json(date_format='iso', orient='split')
    return frame.to_parquet()


@app.callback(
    Output(component_id='graph_subplots', component_property='figure'),
    [
        Input(component_id='store', component_property='data'),
        Input(component_id='dd_ycolnames', component_property='value'),
    ],
)
def update_plot(df_bin, y_colnames):
    """Redraw one subplot per selected column from the stored dataset.

    Args:
        df_bin: Content of the 'store' component, expected to be the
            parquet-serialized dataset.
        y_colnames: Column names selected in the 'dd_ycolnames' dropdown.

    Returns:
        A figure with one row per selected column, all sharing the x axis
        (the 'date' column), or an empty figure ``{}`` when nothing is
        selected.

    NOTE(review): if the store holds no data, ``df_bin`` is ``None`` and
    ``io.BytesIO(None)`` is an empty buffer -- presumably the source of the
    "Parquet file size is 0 bytes" error; confirm.
    """
    if not y_colnames:
        # make_subplots needs at least one row; a cleared dropdown would
        # otherwise raise instead of simply emptying the graph.
        return {}

    logger.info('dataset from json')
    #df = pd.read_json(df_bin, orient='split')
    df = pd.read_parquet(io.BytesIO(df_bin))             # <----------------------
    logger.info('building plot...')

    n_rows = len(y_colnames)
    # plotly rejects a vertical_spacing larger than 1 / (rows - 1). Keep the
    # original 0.1 whenever it is valid, otherwise use half of the limit so
    # that the subplots still get some height.
    spacing = 0.1
    if n_rows > 1:
        max_spacing = 1 / (n_rows - 1)
        if spacing > max_spacing:
            spacing = max_spacing / 2

    fig = make_subplots(
        rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=spacing
    )
    fig.layout.height = 1000
    for row, y_colname in enumerate(y_colnames, start=1):
        series = df[y_colname]
        if series.dtype == 'bool':
            # Plot booleans as 0/1.
            series = series.astype('int')
        # add_trace(row=, col=) replaces the deprecated append_trace.
        fig.add_trace(
            {'x': df['date'], 'y': series.values, 'name': y_colname},
            row=row,
            col=1,
        )
    logger.info('plotted')
    return fig


# When run as a script, start the Dash server on localhost with debug mode on.
if __name__ == '__main__':
    app.run_server(host='localhost', debug=True)

Error text

OSError: Could not open parquet input source '<Buffer>': Invalid: Parquet file size is 0 bytes


Solution

  • Because the data is exchanged between client and server, you are currently limited to JSON serialization. One way to circumvent this limitation is the ServersideOutput component from dash-extensions, which stores the data on the server. It uses file storage and pickle serialization by default, but you can use other storage (e.g. Redis) and/or serialization protocols (e.g. arrow) as well. Here is a small example:

    import time
    import dash_core_components as dcc
    import dash_html_components as html
    import plotly.express as px
    from dash_extensions.enrich import Dash, Output, Input, State, ServersideOutput
    
    # Enriched Dash app from dash-extensions; callbacks are not run on the
    # initial page load.
    app = Dash(prevent_initial_callbacks=True)
    app.layout = html.Div(children=[
        html.Button(children="Query data", id="btn"),
        dcc.Dropdown(id="dd"),
        dcc.Graph(id="graph"),
        # The store is wrapped in a fullscreen "dot" loading indicator.
        dcc.Loading(children=dcc.Store(id='store'), fullscreen=True, type="dot"),
    ])
    
    
    @app.callback(ServersideOutput("store", "data"), Input("btn", "n_clicks"))
    def query_data(n_clicks):
        """Load the dataset when the "Query data" button is clicked.

        The DataFrame is returned as-is: with ``ServersideOutput`` the data
        is stored on the server, so it is not serialized to JSON here.
        """
        time.sleep(1)  # presumably stands in for a slow query
        return px.data.gapminder()  # no JSON serialization here
    
    
    @app.callback(Output("dd", "options"), Input("store", "data"))
    def update_dd(df):
        """Build the dropdown options from the years in the stored DataFrame.

        Args:
            df: The DataFrame held by the 'store' component.

        Returns:
            One ``{"label": year, "value": year}`` option per distinct year,
            in order of first appearance.
        """
        # The frame has many rows per year; iterating the raw column would
        # emit one duplicate option for every row.
        years = df["year"].drop_duplicates()
        return [{"label": year, "value": year} for year in years]  # no JSON de-serialization here
    
    
    @app.callback(Output("graph", "figure"), [Input("dd", "value"), State("store", "data")])
    def update_graph(value, df):
        """Draw a sunburst chart for the year selected in the dropdown.

        Args:
            value: The selected year, or ``None`` when the dropdown is cleared.
            df: The DataFrame held by the 'store' component.

        Returns:
            A sunburst figure (continent -> country, sized by 'pop', coloured
            by 'lifeExp') for the selected year, or an empty figure ``{}``
            when no year is selected.
        """
        if value is None:
            # Nothing selected: clear the graph instead of filtering on None.
            return {}
        # Filter with a boolean mask rather than interpolating the
        # browser-supplied value into a query expression string.
        df = df[df["year"] == value]  # no JSON de-serialization here
        return px.sunburst(df, path=['continent', 'country'], values='pop', color='lifeExp', hover_data=['iso_alpha'])
    
    
    # When run as a script, start the Dash server with its default settings.
    if __name__ == '__main__':
        app.run_server()