python, plotly, sankey-diagram

Prepping data for a Sankey plot in Plotly


I have data as follows:

df = pd.DataFrame({'Id': [1, 2, 3, 4], 'ColA': [30, 20, 20,30], 'ColB':[50, 20, 30,70], 
'ColC':[70, 30, 20,80]})

I want to prepare this for a Sankey plot using Plotly, but I am not sure how to do it. What I want to plot is essentially the Ids as bases and the column values as levels in the data. I've added an image for reference.


Solution

  • There are too many weights for the nodes you suggest. In a Sankey diagram, weights are drawn between nodes, as the widths of the edges (links).

    Your data has 12 values (4 Ids × 3 columns). If your Sankey has 12 nodes, connected only row-wise, then you can have only 9 weights: one between each pair of adjacent rectangles in your picture (3 rows of 4 nodes, so 3 links per row). I arbitrarily chose to use the first three rows of your data for columns A, B, and C.
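As a quick check of the arithmetic above, here is a minimal sketch (the helper name `count_edges` is hypothetical) that counts the links when each column becomes a chain of one node per Id and the chains are not linked to each other:

```python
def count_edges(n_columns: int, n_ids: int) -> int:
    """Number of links when each column is a chain of n_ids nodes."""
    # a chain of n nodes has n - 1 links; separate chains are never joined
    return n_columns * (n_ids - 1)


n_nodes = 3 * 4                  # ColA, ColB, ColC x Ids 1..4
n_values = 3 * 4                 # one value per cell of the data frame
n_links = count_edges(3, 4)
print(n_nodes, n_values, n_links)  # 12 12 9
```

Twelve values but only nine links is why some of the data has to be left out.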

    Nodes

    First, I created the node labels, one for each of the 12 nodes, by concatenating your column and row names. Notice that this list is ordered first by column, then by row. You don't have to do it this way, but links refer to nodes by index position, so putting the labels in a predictable order now makes the later steps easier.
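A minimal sketch of the same labeling idea, plus a reverse lookup. The `index_of` dict is an addition for illustration, not part of the answer's code; it lets you address nodes by name instead of counting positions:

```python
columns = ["colA", "colB", "colC"]
ids = [1, 2, 3, 4]

# column-first order: all colA labels, then colB, then colC
labels = [col + str(i) for col in columns for i in ids]

# map each label back to its index position, which is what Sankey links use
index_of = {label: pos for pos, label in enumerate(labels)}

print(labels[:4])          # ['colA1', 'colA2', 'colA3', 'colA4']
print(index_of["colB1"])   # 4
```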

    Next, you need to create the source, target, and weight lists.

    Source, Target, and Weights

    Your source and target lists hold the index positions of the from and to nodes. So the first edge is between ColA id1 and ColA id2; therefore, the first source element is 0 and the first target element is 1.
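To make the index convention concrete, here is a minimal sketch that writes a few edges by label and converts them to the integer positions Plotly expects. The `edges` list is hypothetical and covers only the ColA chain:

```python
columns = ["colA", "colB", "colC"]
ids = [1, 2, 3, 4]
labels = [col + str(i) for col in columns for i in ids]
index_of = {label: pos for pos, label in enumerate(labels)}

# hypothetical edge list written with readable names: (from_label, to_label)
edges = [("colA1", "colA2"), ("colA2", "colA3"), ("colA3", "colA4")]

# source and target are parallel lists of integer node positions
source = [index_of[a] for a, _ in edges]
target = [index_of[b] for _, b in edges]
print(source, target)  # [0, 1, 2] [1, 2, 3]
```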

    Since the 4th element (index 3) is the last in its row and the rows aren't connected, I generated the source and target lists with range() and then removed the elements that would link the end of one row to the beginning of the next (i.e., ColA id4 does not connect to ColB id1).

    (If this doesn't make sense, comment out the code that contains .remove(), then plot. It should clarify things!)
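An alternative that skips the `.remove()` calls: a minimal sketch, assuming the same column-first node order with 4 Ids per column, that filters out the chain ends with a modulo test and yields the same lists as the answer's code:

```python
n_columns, n_ids = 3, 4

# link node k to node k + 1 unless k is the last node of its chain
source = [k for k in range(n_columns * n_ids - 1) if (k + 1) % n_ids != 0]
target = [k + 1 for k in source]

print(source)  # [0, 1, 2, 4, 5, 6, 8, 9, 10]
print(target)  # [1, 2, 3, 5, 6, 7, 9, 10, 11]
```

The advantage is that the chain length is a parameter, so changing the number of Ids does not require working out which indices to remove by hand.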

    For the weights, I used .iloc to extract the columns other than Id, and rows 1–3. To turn this into a flat sequence, I converted it to a NumPy array, transposed it (rows to columns, columns to rows), then flattened it into a 1-D array. (This puts the weights in the same column-first order as the nodes, and therefore in the same order as the source and target lists.)
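The transpose-then-flatten step, sketched on the first three rows of the question's data as a plain NumPy array so it runs without the DataFrame. `flatten(order="F")` (column-major order) is shown as an equivalent one-step spelling:

```python
import numpy as np

# ColA, ColB, ColC values for the first three rows (Ids 1-3)
block = np.array([[30, 50, 70],
                  [20, 20, 30],
                  [20, 30, 20]])

# transposing then flattening reads the block column by column...
wts = np.transpose(block).flatten()

# ...which matches flattening in column-major ("F") order
assert (wts == block.flatten(order="F")).all()
print(wts.tolist())  # [30, 20, 20, 50, 20, 30, 70, 30, 20]
```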

    Plot

    Finally, the plot:

    import plotly.graph_objects as go
    import pandas as pd
    import numpy as np
    
    df = pd.DataFrame(
        {'Id': [1, 2, 3, 4], 'ColA': [30, 20, 20, 30], 'ColB': [50, 20, 30, 70],
         'ColC': [70, 30, 20, 80]})
    
    # more weights than edges: use only the first 3 rows for the weights
    
    # use the depicted column/row names as node labels (column first, then Id):
    nds = [col + str(row) for col in ["colA", "colB", "colC"] for row in range(1, 5)]
    print(nds)
    # ['colA1', 'colA2', 'colA3', 'colA4', 'colB1', 'colB2', 'colB3', 'colB4',
    # 'colC1', 'colC2', 'colC3', 'colC4']
    
    # notice that index 0 is colA, row 1: the first source entry
    # where is it going to? That's your target
    
    sre = list(range(0, 11))   # source list: node indices 0..10
    sre.remove(3)              # remove row connectors: colA4 -> colB1
    sre.remove(7)              # and colB4 -> colC1
    trg = list(range(1, 12))   # target list: node indices 1..11
    trg.remove(4)              # drop colB1 as a target (no link from colA4)
    trg.remove(8)              # drop colC1 as a target (no link from colB4)
    
    # now the weights: the values in the data frame (leave out Id and the last row)
    wts = df.iloc[0:3, 1:4].to_numpy()
    
    # transpose, so columns/rows are swapped, then flatten to a 1-D array
    wts = np.transpose(wts).flatten()
    
    # plot it
    fig = go.Figure([go.Sankey(
        node = dict(label = nds), link = dict(source = sre, target = trg, value = wts)
    )])
    
    fig.show() # gimme
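If you want to change how many rows are used, or add columns, the index bookkeeping can be wrapped in a small function. This is a minimal sketch, not the answer's method: the name `sankey_lists` and its parameters are hypothetical, and it uses plain Python with the question's data passed as a dict, the shape that `df.to_dict("list")` returns:

```python
def sankey_lists(data, columns, n_rows):
    """Build Sankey labels, sources, targets and values (hypothetical helper).

    `data` maps column name -> list of values (e.g. df.to_dict("list")).
    Each column becomes a chain of n_rows + 1 nodes, and the first n_rows
    values of that column are used as the link weights along the chain.
    """
    labels, source, target, value = [], [], [], []
    for col in columns:
        start = len(labels)  # index of this chain's first node
        labels += [f"{col}{i}" for i in range(1, n_rows + 2)]
        for r in range(n_rows):
            source.append(start + r)
            target.append(start + r + 1)
            value.append(data[col][r])
    return labels, source, target, value


data = {'Id': [1, 2, 3, 4], 'ColA': [30, 20, 20, 30],
        'ColB': [50, 20, 30, 70], 'ColC': [70, 30, 20, 80]}
labels, source, target, value = sankey_lists(data, ["ColA", "ColB", "ColC"], 3)
print(source)  # [0, 1, 2, 4, 5, 6, 8, 9, 10]
print(value)   # [30, 20, 20, 50, 20, 30, 70, 30, 20]
```

The four returned lists can then be passed to the same figure call as above, e.g. `go.Sankey(node=dict(label=labels), link=dict(source=source, target=target, value=value))`.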
    
    
