Search code examples
python sankey-diagram

Python: Sankey plot chart with complex data


I have the following dataset:

import pandas as pd

# Each key is a column name and each list holds that column's ten values.
# Note that 'NAN' is a literal string here, not a real missing value.
data = {
    '1': ['A', 'B', 'C', 'NAN', 'A', 'C', 'NAN', 'C', 'B', 'A'],
    '2': ['B', 'NAN', 'A', 'B', 'C', 'A', 'B', 'NAN', 'A', 'C'],
    '3': ['NAN', 'A', 'B', 'C', 'NAN', 'B', 'A', 'B', 'C', 'A'],
    '4': ['C', 'B', 'NAN', 'A', 'B', 'NAN', 'C', 'A', 'NAN', 'B']
}
df = pd.DataFrame(data)

and I want to create a simple Sankey plot from this data structure. I don't even know where to start...


Solution

  • It is tricky to get the data in the correct shape. Perhaps there is a more efficient way than what I have come up with, but hopefully this gets the job done.

    import plotly.graph_objects as go
    import pandas as pd
    import numpy as np
    # Example input: each key is a column (a stage of the flow) and each list
    # holds the category seen in rows 0-9 at that stage. The string 'NAN'
    # marks a missing observation.
    data = {
        '1': ['A', 'B', 'C', 'NAN', 'A', 'C', 'NAN', 'C', 'B', 'A'],
        '2': ['B', 'NAN', 'A', 'B', 'C', 'A', 'B', 'NAN', 'A', 'C'],
        '3': ['NAN', 'A', 'B', 'C', 'NAN', 'B', 'A', 'B', 'C', 'A'],
        '4': ['C', 'B', 'NAN', 'A', 'B', 'NAN', 'C', 'A', 'NAN', 'B']
    }
    df = pd.DataFrame(data)
    # Turn the 'NAN' placeholder into a real missing value so that dropna()
    # and value_counts() skip it.
    df = df.replace('NAN', np.nan)

    # Suffix every value with its column name ('A' in column '1' -> 'A1') so
    # that the same category in different columns becomes a distinct node.
    # Missing values stay missing.
    tagged = df.apply(lambda col: col + str(col.name))

    # Node labels: the distinct tagged values, sorted for a stable order,
    # plus a lookup from label to its position in that list.
    label = sorted(tagged.melt().dropna()['value'].unique())
    node_index = {name: pos for pos, name in enumerate(label)}

    # Count the transitions between each pair of adjacent columns. Rows with a
    # missing value on either side are dropped by value_counts(). Iterating
    # over the resulting Series avoids touching the count column by name: it
    # is called 'count' in pandas >= 2.0, so adding the column name to it (as
    # is done for the value columns) would raise a TypeError there.
    mapped = []
    for left, right in zip(df.columns[:-1], df.columns[1:]):
        pair_counts = tagged[[left, right]].value_counts()
        for (src, dst), count in pair_counts.items():
            mapped.append((node_index[src], node_index[dst], count))

    # Split the (source, target, value) triples into three parallel arrays.
    # The reshape keeps the unpacking valid even when there are no links.
    source, target, value = np.array(mapped, dtype=int).reshape(-1, 3).T
    
    # Describe the nodes: spacing, bar thickness, outline, caption and colour.
    node_spec = {
        "pad": 15,
        "thickness": 20,
        "line": {"color": "black", "width": 0.5},
        "label": label,
        "color": "blue",
    }

    # Describe the links with the three parallel sequences built earlier.
    link_spec = {"source": source, "target": target, "value": value}

    # Wrap the Sankey trace in a figure, set the title and font size, and
    # display it.
    sankey_trace = go.Sankey(node=node_spec, link=link_spec)
    fig = go.Figure(data=[sankey_trace])
    fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
    fig.show()
    

    Output

    enter image description here