Tags: python, graph, networkx, plotly-dash, graph-visualization

How to show the result of NetworkX graph in Plotly.Dash using Dash callbacks


I have a .csv file containing a list of edges between the nodes of a graph. I read the file and store it in a Dash Store component like this:

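import io
import pandas as pd
from dash import dcc, html
# inside the upload callback; `decoded` holds the uploaded file's bytes
dataset = pd.read_csv(io.StringIO(decoded.decode('utf-8')))
del dataset[dataset.columns[0]]  # drop the first column
return html.Div(className="mx-auto text-center", children=[
    dcc.Store(id="approach-1-dataset", data=dataset.to_dict('records'))])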
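A dcc.Store can only hold JSON-serializable data, which is why the DataFrame is converted with to_dict('records') before it is stored. The sketch below shows that round trip; json.dumps/json.loads only approximate what Dash does when it sends the Store's data to the browser and back, and the sample rows are the ones from the CSV shown further down.

```python
import io
import json

import pandas as pd

# Sample rows from the question's CSV (member1, member2, weight)
csv_text = """member1,member2,weight
200114206,3949436,1
217350178,8539046,1
193986670,8539046,2
"""

df = pd.read_csv(io.StringIO(csv_text))

# A list of plain dicts, one per row: [{'member1': ..., 'member2': ..., 'weight': ...}, ...]
records = df.to_dict("records")

# Approximate the trip to the browser and back into the callback's State
payload = json.dumps(records)
restored = pd.DataFrame(json.loads(payload))
```

In the callback, pd.DataFrame(dataset) rebuilds the same frame from the stored records.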

Then, when the user clicks a button in the view, the graph is created with NetworkX:

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from dash import dcc, Input, Output, State

# `app` is the existing dash.Dash instance
@app.callback(Output('visualization-container', 'children'),
              Input('visualize-button', 'n_clicks'),
              State('approach-1-dataset', 'data'))  # must match the dcc.Store id
def visualize_graph(n1, dataset):
    if n1:
        main_dataset = pd.DataFrame(dataset)  # rebuild the frame from the stored records
        pd.set_option('display.precision', 10)  # plain 'precision' is ambiguous in newer pandas
        G = nx.from_pandas_edgelist(main_dataset, 'member1', 'member2', create_using=nx.Graph())
        nodes = G.nodes()
        degree = G.degree()
        colors = [degree[n] for n in nodes]
        size = [degree[n] for n in nodes]
        pos = nx.spring_layout(G, k=0.2)
        cmap = plt.cm.Greys
        fig = plt.figure(figsize=(15, 9), dpi=100)
        nx.draw(G, pos, alpha=0.8, nodelist=nodes, node_color=colors, node_size=size,
                with_labels=False, font_size=6, width=0.2, cmap=cmap, edge_color='yellow')
        fig.set_facecolor('#0B243B')
        return dcc.Graph(figure=fig)  # fig is a Matplotlib figure, not a Plotly figure
    return ""

Using this code I get the following error in my view:

Callback error updating visualization-container.children

dash.exceptions.InvalidCallbackReturnValue: The callback for <Output visualization-container.children> returned a value having type Graph which is not JSON serializable. The value in question is either the only value returned, or is in the top level of the returned list, and has string representation Graph(figure=<Figure size 1500x900 with 1 Axes>) In general, Dash properties can only be dash components, strings, dictionaries, numbers, None, or lists of those.

And this error in my console:

Assertion failed: (NSViewIsCurrentlyBuildingLayerTreeForDisplay() != currentlyBuildingLayerTree), function NSViewSetCurrentlyBuildingLayerTreeForDisplay, file NSView.m, line 13477.
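The console assertion appears to come from Matplotlib's macOS GUI backend: Dash's development server runs callbacks in worker threads, and pyplot with an interactive backend tries to set up native window machinery there. Matplotlib's own documentation recommends not using pyplot in web servers and creating matplotlib.figure.Figure objects directly. The sketch below follows that advice (and also selects the non-interactive Agg backend) to turn the Store's records into a PNG data URI for html.Img. The helper name records_to_png_uri is hypothetical, and the worker thread only imitates how Dash calls callbacks.

```python
import base64
import io
import threading

import matplotlib

matplotlib.use("Agg")  # non-interactive backend: never opens a native window

import networkx as nx
import pandas as pd
from matplotlib.figure import Figure


def records_to_png_uri(records):
    """Hypothetical helper: Store records -> 'data:image/png;base64,...' for html.Img."""
    df = pd.DataFrame(records)
    G = nx.from_pandas_edgelist(df, "member1", "member2", create_using=nx.Graph())
    degree = G.degree()
    nodes = list(G.nodes())
    colors = [degree[n] for n in nodes]

    fig = Figure(figsize=(15, 9), dpi=100)  # created without pyplot
    ax = fig.subplots()
    pos = nx.spring_layout(G, k=0.2, seed=42)
    nx.draw(G, pos, ax=ax, alpha=0.8, nodelist=nodes, node_color=colors,
            node_size=[degree[n] for n in nodes], with_labels=False,
            width=0.2, cmap="Greys", edge_color="yellow")
    fig.set_facecolor("#0B243B")  # set after nx.draw, as in the question

    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", facecolor=fig.get_facecolor())
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/png;base64,{encoded}"


# Imitate a Dash callback running in a worker thread
records = [
    {"member1": 200114206, "member2": 3949436, "weight": 1},
    {"member1": 217350178, "member2": 8539046, "weight": 1},
]
result = {}
worker = threading.Thread(target=lambda: result.update(uri=records_to_png_uri(records)))
worker.start()
worker.join()
```

In the question's callback, this would mean returning html.Img(src=records_to_png_uri(dataset)) instead of dcc.Graph(figure=fig).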

The same code works fine when I run it directly in a Jupyter notebook, but when I run it in a Dash callback and return the result as a dcc.Graph component, I get these errors.
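The difference from Jupyter is that a notebook renders the Matplotlib figure as an image itself, while the figure property of dcc.Graph must be a Plotly figure (a plotly.graph_objects.Figure or an equivalent plain dict) that Dash can serialize to JSON. One way to keep dcc.Graph is to build that dict from the NetworkX layout. This sketch assumes the common technique of drawing all edges as a single line trace with None separators; the helper name networkx_to_plotly_dict is hypothetical.

```python
import json

import networkx as nx


def networkx_to_plotly_dict(G, seed=42):
    """Hypothetical helper: NetworkX graph -> Plotly figure dict for dcc.Graph."""
    pos = nx.spring_layout(G, k=0.2, seed=seed)  # same layout as the question, seeded

    # All edges in one line trace; None breaks the line between consecutive edges
    edge_x, edge_y = [], []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x += [float(x0), float(x1), None]
        edge_y += [float(y0), float(y1), None]

    nodes = list(G.nodes())
    edge_trace = {
        "type": "scatter", "mode": "lines", "x": edge_x, "y": edge_y,
        "line": {"width": 0.5, "color": "yellow"}, "hoverinfo": "none",
    }
    node_trace = {
        "type": "scatter", "mode": "markers",
        "x": [float(pos[n][0]) for n in nodes],
        "y": [float(pos[n][1]) for n in nodes],
        "text": [str(n) for n in nodes], "hoverinfo": "text",
        # colour nodes by degree, like the Matplotlib version
        "marker": {"color": [G.degree(n) for n in nodes], "colorscale": "Greys", "size": 10},
    }
    return {
        "data": [edge_trace, node_trace],
        "layout": {
            "showlegend": False,
            "paper_bgcolor": "#0B243B", "plot_bgcolor": "#0B243B",
            "xaxis": {"visible": False}, "yaxis": {"visible": False},
        },
    }


G = nx.Graph([(200114206, 3949436), (217350178, 8539046), (193986670, 8539046)])
figure = networkx_to_plotly_dict(G)
serialized = json.dumps(figure)  # only lists, numbers, strings and None: JSON-safe
```

The callback would then return dcc.Graph(figure=networkx_to_plotly_dict(G)).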

How can I solve this issue?

My .csv file looks like this:

member1,member2,weight
200114206,3949436,1
217350178,8539046,1
...
193986670,8539046,2
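The callback above ignores the weight column. If it is needed, nx.from_pandas_edgelist can copy it onto the edges through its edge_attr parameter. A small sketch using the rows shown above:

```python
import io

import networkx as nx
import pandas as pd

csv_text = """member1,member2,weight
200114206,3949436,1
217350178,8539046,1
193986670,8539046,2
"""

df = pd.read_csv(io.StringIO(csv_text))
# edge_attr="weight" stores each row's weight as an edge attribute
G = nx.from_pandas_edgelist(df, "member1", "member2", edge_attr="weight")

print(G.number_of_nodes(), G.number_of_edges())  # 5 3
print(G[193986670][8539046]["weight"])  # 2
print(G.degree(8539046))  # 2
```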


Solution

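  • dcc.Graph expects a Plotly figure, so passing it a Matplotlib figure is what triggers the "not JSON serializable" error. Instead, you can render the Matplotlib figure to a base64-encoded PNG and display it as an html.Img in Dash. Try the snippet below for a working example.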

    If you want an interactive network graph rather than a static image, don't use dcc.Graph; use cyto.Cytoscape from Dash Cytoscape instead. See https://dash.plotly.com/cytoscape

    import base64
    import io

    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend: no GUI window, avoids the macOS NSView assertion
    import matplotlib.pyplot as plt
    import networkx as nx
    import pandas as pd

    import dash
    from dash import html, dcc
    from dash.dependencies import Output, Input

    dataset = [
        (200114206, 3949436, 1),
        (217350178, 8539046, 1)
    ]

    main_dataset = pd.DataFrame(dataset, columns=["member1", "member2", "weight"])
    G = nx.from_pandas_edgelist(main_dataset, 'member1', 'member2', create_using=nx.Graph())
    nodes = G.nodes()
    degree = G.degree()
    colors = [degree[n] for n in nodes]
    size = [degree[n] for n in nodes]
    pos = nx.spring_layout(G, k=0.2)
    cmap = plt.cm.Greys
    fig = plt.figure(figsize=(15, 9), dpi=100)
    nx.draw(G, pos, alpha=0.8, nodelist=nodes, node_color=colors, node_size=size,
            with_labels=False, font_size=6, width=0.2, cmap=cmap, edge_color='yellow')
    fig.set_facecolor('#0B243B')

    # Render the figure to PNG in memory (format must match the data URI below)
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", facecolor=fig.get_facecolor())
    plt.close(fig)  # free the figure once it has been rendered
    base64_encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')

    app = dash.Dash(__name__)
    server = app.server
    app.config.suppress_callback_exceptions = True

    app.layout = html.Div(children=[
        html.Button(id="visualisation_button", children="click me"),
        html.Div(id="visualisation_block")
    ])

    @app.callback(
        Output("visualisation_block", "children"),
        Input("visualisation_button", "n_clicks"))
    def update_vis(n_clicks):
        if n_clicks:
            # html.Img accepts a data URI, so the PNG is embedded directly
            return html.Img(id='nxplot_img',
                            src=f'data:image/png;base64,{base64_encoded_image}',
                            style={'height': '50%', 'width': '50%'})
        return None

    if __name__ == "__main__":
        app.run()  # app.run_server() on older Dash versions
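For the Cytoscape route mentioned in the answer, the cyto.Cytoscape component (from the separate dash-cytoscape package) takes the graph as a list of element dicts: nodes as {'data': {'id': ...}} and edges as {'data': {'source': ..., 'target': ...}}. A sketch of converting a NetworkX graph into that list; the helper name networkx_to_cytoscape_elements is hypothetical, and the commented-out component call is only indicative.

```python
import networkx as nx


def networkx_to_cytoscape_elements(G):
    """Hypothetical helper: NetworkX graph -> dash-cytoscape `elements` list."""
    # ids are given as strings, as in the Dash Cytoscape examples
    nodes = [{"data": {"id": str(n), "label": str(n)}} for n in G.nodes()]
    edges = [
        {"data": {"source": str(u), "target": str(v), "weight": d.get("weight", 1)}}
        for u, v, d in G.edges(data=True)
    ]
    return nodes + edges


G = nx.Graph()
G.add_edge(200114206, 3949436, weight=1)
G.add_edge(217350178, 8539046, weight=1)
elements = networkx_to_cytoscape_elements(G)

# In the Dash layout or a callback (requires `import dash_cytoscape as cyto`):
# cyto.Cytoscape(id="network", elements=elements, layout={"name": "cose"})
```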