Search code examples
python-3.x, pandas, plotly, plotly-python

plotly python Sankey Plot


I am trying to create a Sankey diagram in Python. The idea is to show how the size of each Topic changes month over month. This is a sample of my pandas DataFrame. The real data has more Topics, and each Topic spans more months and years, which makes the DataFrame tall.

 df
    year    month   Topic   Document_Size   
  0 2022    1        0.0            63  
  1 2022    1        1.0            120 
  2 2022    1        2.0            106 
  3 2022    2        0.0            70  
  4 2022    2        1.0            42  
  5 2022    2        2.0            45  
  6 2022    3        0.0            78  
  7 2022    3        1.0            14  
  8 2022    3        2.0            84

I have prepared the following code from the plotly demo. I am missing the values that should go into the variables node_label, source_node and target_node for the code to work, so I am not getting the correct plot output.

 node_label = ?
 source_node = ?
 target_node = ?
 values = df['Document_Size']

from webcolors import hex_to_rgb
%matplotlib inline

from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objects as go # Import the graphical object

# Build a figure that contains a single Sankey trace.
fig = go.Figure( 
data=[go.Sankey( # The plot we are interested in
    # Node information: the display labels of the nodes
    node = dict( 
        label = node_label
    ),
    # Link information: source, target and value are supplied together.
    # NOTE(review): presumably these are parallel sequences with one entry
    # per link, and source/target are integer positions into the node
    # labels -- confirm against the plotly Sankey reference.
    link = dict(
        source = source_node,
        target = target_node,
        value = values
    ))])

 # Pass the figure to plotly.offline.plot, requesting a PNG image named
 # 'sankey_plot_1' with a width of 5000 and a height of 3000
 plot(fig,
 image_filename='sankey_plot_1', 
 image='png', 
 image_width=5000, 
 image_height=3000)

 # Then display the figure
 fig.show()

Solution

    • Reuse the approach from this answer: sankey from dataframe
    • Restructure the dataframe so that it has the structure used in that answer (one row per link, with source and target columns):
    Document_Size source target
    63 2022 01 0.0 2022 02 0.0
    120 2022 01 1.0 2022 02 1.0
    106 2022 01 2.0 2022 02 2.0
    70 2022 02 0.0 2022 03 0.0
    42 2022 02 1.0 2022 03 1.0
    45 2022 02 2.0 2022 03 2.0
    import pandas as pd
    import io
    import numpy as np
    import plotly.graph_objects as go
    
    # Parse the whitespace-delimited sample table. The header names four
    # columns while every data row has five fields, so pandas takes the
    # leading row number as the index.
    df = pd.read_csv(
        io.StringIO(
            """    year    month   Topic   Document_Size   
      0 2022    1        0.0            63  
      1 2022    1        1.0            120 
      2 2022    1        2.0            106 
      3 2022    2        0.0            70  
      4 2022    2        1.0            42  
      5 2022    2        2.0            45  
      6 2022    3        0.0            78  
      7 2022    3        1.0            14  
      8 2022    3        2.0            84"""
        ),
        # Raw string: a backslash followed by "s" is not a valid escape in a
        # normal string literal and triggers a SyntaxWarning on recent Python.
        sep=r"\s+",
    )

    # Collapse year + month into one timestamp (first day of the month) so
    # that periods order chronologically, including across year boundaries.
    df["date"] = pd.to_datetime(df.assign(day=1).loc[:, ["year", "month", "day"]])

    # Index by (date, Topic) so that one period's rows can be selected with
    # df.loc[date].
    df = df.drop(columns=["year", "month"]).set_index(["date", "Topic"])

    # Sorted explicitly so that consecutive pairs are always
    # (period, next period), even if the input rows are not in date order.
    dates = df.index.get_level_values(0).unique().sort_values()

    # For each pair of current date and next date, construct a segment of
    # source / target data: one row per Topic, carrying that Topic's
    # Document_Size in the current period. The last period therefore only
    # ever appears as a target.
    df_sankey = pd.concat(
        [df.loc[s].assign(source=s, target=t) for s, t in zip(dates, dates[1:])]
    )

    # Turn the dates into node labels of the form "<year> <month> <topic>",
    # e.g. "2022 01 0.0". After the concat the index holds only the Topic
    # level, so it supplies the label suffix.
    topic_suffix = df_sankey.index.astype(str)
    for column in ("source", "target"):
        df_sankey[column] = df_sankey[column].dt.strftime("%Y %m ") + topic_suffix

    # Map every distinct node label to its integer position; the link
    # endpoints are looked up through this Series.
    nodes = np.unique(df_sankey[["source", "target"]], axis=None)
    nodes = pd.Series(index=nodes, data=range(len(nodes)))
    
    # Build the Sankey figure. Node labels are the node names; each link's
    # source and target are the integer positions of those names, looked up
    # through the `nodes` Series, and its value is the Document_Size.
    # The figure is bound to a name and shown explicitly so that it also
    # renders when run as a plain script, where a bare expression statement
    # would be built and then discarded.
    fig = go.Figure(
        go.Sankey(
            node={"label": nodes.index},
            link={
                "source": nodes.loc[df_sankey["source"]],
                "target": nodes.loc[df_sankey["target"]],
                "value": df_sankey["Document_Size"],
            },
        )
    )
    fig.show()
    

    output

    enter image description here