Plotly: How to plot Sankey diagram with matching rows across different columns?


I am plotting a Sankey diagram with Plotly to compare different classifications of the same observations. With more than two classifications I run into a problem: within a node, the order of the observations changes between the node's inputs and its outputs.

The code I am using is the following:

def pl_sankey(df, label_color, categories, value, title='Sankey Diagram', fname=None, width=3000, height=1600, scale=2):
    from IPython.display import Image
    import plotly.graph_objects as go
    import pandas as pd
    df = df.copy()
    labels = []
    colors = []
    # associate labels to colors
    for k, v in label_color.items():
        labels += [k]
        colors += [v]
    # transform df into a source-target pair
    pairs = []
    for i in range(len(categories)-1):
        # one frame per pair of adjacent columns
        _st_df = df[[categories[i], categories[i+1], value]].copy()
        _st_df.columns = ['source', 'target', 'count']
        pairs.append(_st_df)
    # aggregate once, after all pairs have been collected
    st_df = pd.concat(pairs).groupby(['source', 'target']).agg({'count': 'sum'}).reset_index()
    # add index for source-target pair
    st_df['sourceID'] = st_df['source'].apply(lambda x: labels.index(str(x)))
    st_df['targetID'] = st_df['target'].apply(lambda x: labels.index(str(x)))
    # creating the sankey diagram
    data = dict(
        type='sankey', node=dict(
            pad=15, thickness=20, line = dict(color='black', width=0.5), label=labels, color=colors,
        ),
        link=dict(source=st_df['sourceID'], target=st_df['targetID'], value=st_df['count']),
    )
    layout = dict(title=title, font=dict(size=16, family='Arial'))  
    # creating figure
    fig = go.Figure(dict(data=[data], layout=layout))
    if fname:
        fig.write_image(f'{fname}.pdf', format='pdf', width=width, height=height, scale=scale)
    return Image(fig.to_image(format='png', width=width, height=height, scale=scale))

The input parameters are:

  • a pandas DataFrame df with groupings for each set of rows, e.g.:
# g1_l1 means group1, label1

       g1      g2      g3   counts
0   g1_l1   g2_l1   g3_l1   10
1   g1_l3   g2_l2   g3_l1   1
2   g1_l1   g2_l2   g3_l2   1
3   g1_l2   g2_l2   g3_l1   40
4   g1_l2   g2_l3   g3_l2   20
5   g1_l3   g2_l1   g3_l2   10
  • label_color is a dictionary whose keys are labels and whose values are colors
  • categories is the list of grouping column names, in this case ['g1', 'g2', 'g3']
  • value is the name of the column holding the counts, in this case 'counts'

One example of execution is the following:

import pandas as pd

df = pd.DataFrame([
    ['g1_l1', 'g2_l1', 'g3_l1', 10],
    ['g1_l3', 'g2_l2', 'g3_l1', 1],
    ['g1_l1', 'g2_l2', 'g3_l2', 1],
    ['g1_l2', 'g2_l2', 'g3_l1', 40],
    ['g1_l2', 'g2_l3', 'g3_l2', 20],
    ['g1_l3', 'g2_l1', 'g3_l2', 10],
], columns=['g1', 'g2', 'g3', 'counts'])

label_color = {
    'g1_l1': '#1f77b4', 'g1_l2': '#ff7f0e', 'g1_l3': '#279e68',
    'g2_l1': '#1f77b4', 'g2_l2': '#ff7f0e', 'g2_l3': '#279e68',
    'g3_l1': '#1f77b4', 'g3_l2': '#ff7f0e',
}

pl_sankey(df, label_color, categories=df.columns[:-1], value='counts', title='', fname=None)

[Image: sankey example]

However, this code only guarantees row matching between two adjacent columns. Consider, for example, row 1:

       g1      g2      g3   counts
1   g1_l3   g2_l2   g3_l1   1

Such a row should start from the green cluster (g1_l3) in the first column, land in the orange cluster (g2_l2) in the second column and continue to the blue cluster (g3_l1) in the third column. However, the plot above does not respect this: the links entering a node in the second column are not ordered the same way as the matching links leaving it.

The annotated plot below shows the observation jumping in the second column (it is second to last among the inputs of the second column, but last among its outputs):

[Image: observation jumps]

I would like to follow the path of a row from the first to the last column. Is this possible with a Sankey diagram, and if so, how?


Solution

  • I may have misunderstood something here, so please forgive me if I'm wrong, but it seems you may have misunderstood some of the inner workings of a Plotly Sankey diagram. Don't worry, you're not alone.

    You're stating that:

    Such a row should start from the green cluster (g1_l3) in the first column, land in the orange cluster (g2_l2) in the second column and continue to the blue cluster (g3_l1) in the third column

    So if I understand correctly, you're expecting this particular relationship to be illustrated as:

    [image]

    But that's not how a Plotly Sankey diagram works. Rather, the quantity going from g1_l3 to g2_l2 is grouped together with the other quantities going into g2_l2 and then "sent" along as part of an aggregated value to g3_l1. The reason you see this link:

    [image]

    ... is that you also have the relationship g2_l2, g3_l1, 1:

    [image]
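
    To make the aggregation concrete, here is a minimal sketch in plain Python (standard library only, so it runs without pandas or Plotly). The helper name aggregate_links is hypothetical; it is meant to mirror what the groupby on source and target in pl_sankey does, using the rows from the question.

```python
from collections import defaultdict

# Rows of the example dataframe: (g1, g2, g3, counts).
rows = [
    ("g1_l1", "g2_l1", "g3_l1", 10),
    ("g1_l3", "g2_l2", "g3_l1", 1),
    ("g1_l1", "g2_l2", "g3_l2", 1),
    ("g1_l2", "g2_l2", "g3_l1", 40),
    ("g1_l2", "g2_l3", "g3_l2", 20),
    ("g1_l3", "g2_l1", "g3_l2", 10),
]


def aggregate_links(rows):
    """Sum the counts per (source, target) pair over adjacent columns."""
    links = defaultdict(int)
    for *labels, count in rows:
        # each row contributes one link per pair of adjacent columns
        for source, target in zip(labels, labels[1:]):
            links[(source, target)] += count
    return dict(links)


links = aggregate_links(rows)

# Rows 1 and 3 both go from g2_l2 to g3_l1, so they are merged into one link.
print(links[("g2_l2", "g3_l1")])  # 41
```

    Once the counts are summed, nothing in the link g2_l2 -> g3_l1 records which part of its value 41 arrived from g1_l3 and which from g1_l2, so the diagram has no way to continue the band of row 1 specifically.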

    If you somehow succeeded in drawing the relationships in your dataframe exactly as you describe, the result would no longer be a Sankey diagram.
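
    One possible approximation that stays within a Sankey figure is to skip the aggregation and emit one link per row and per pair of adjacent columns, giving all links of a row the same color so the eye can follow them. This is only a sketch under assumptions that are not part of the question: per_row_links and row_colors are hypothetical names, the colors are arbitrary, and it relies on my understanding that Plotly draws links sharing the same source and target as separate bands. Plotly still chooses the vertical order of links at each node, so a row is not guaranteed to enter and leave a node at the same height.

```python
# Rows of the example dataframe: (g1, g2, g3, counts).
rows = [
    ("g1_l1", "g2_l1", "g3_l1", 10),
    ("g1_l3", "g2_l2", "g3_l1", 1),
    ("g1_l1", "g2_l2", "g3_l2", 1),
    ("g1_l2", "g2_l2", "g3_l1", 40),
    ("g1_l2", "g2_l3", "g3_l2", 20),
    ("g1_l3", "g2_l1", "g3_l2", 10),
]

# Node labels, in the same order as the keys of label_color in the question.
labels = ["g1_l1", "g1_l2", "g1_l3", "g2_l1", "g2_l2", "g2_l3", "g3_l1", "g3_l2"]

# Hypothetical: one semi-transparent color per row, chosen arbitrarily.
row_colors = [
    "rgba(31,119,180,0.5)", "rgba(214,39,40,0.5)", "rgba(148,103,189,0.5)",
    "rgba(255,127,14,0.5)", "rgba(140,86,75,0.5)", "rgba(39,158,104,0.5)",
]


def per_row_links(rows, labels, row_colors):
    """Build one link per row and per pair of adjacent columns, without summing."""
    index = {label: i for i, label in enumerate(labels)}
    link = {"source": [], "target": [], "value": [], "color": []}
    for (*path, count), color in zip(rows, row_colors):
        for source, target in zip(path, path[1:]):
            link["source"].append(index[source])
            link["target"].append(index[target])
            link["value"].append(count)
            link["color"].append(color)  # same color for every link of this row
    return link


link = per_row_links(rows, labels, row_colors)

# With plotly installed, the dict can be passed to a Sankey trace, e.g.:
# import plotly.graph_objects as go
# fig = go.Figure(go.Sankey(node=dict(label=labels), link=link))
```

    The nodes still aggregate the flows, as described above; the coloring only makes it easier to see which incoming and outgoing bands belong to the same row.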

    I'm sorry this is all I could do for you at the moment.