I am plotting a Sankey diagram with plotly to compare different classifications of observations. However, I am having issues with more than two classifications: the order of observations in each classification changes between the inputs and the outputs of each node.
The code I am using is the following:
def pl_sankey(df, label_color, categories, value, title='Sankey Diagram', fname=None, width=3000, height=1600, scale=2):
    from IPython.display import Image
    import plotly.graph_objects as go
    import pandas as pd
    df = df.copy()
    labels = []
    colors = []
    # associate labels to colors
    for k, v in label_color.items():
        labels += [k]
        colors += [v]
    # transform df into a source-target pair
    st_df = None
    for i in range(len(categories)-1):
        _st_df = df[[categories[i], categories[i+1], value]]
        _st_df.columns = ['source', 'target', 'count']
        st_df = pd.concat([st_df, _st_df])
        st_df = st_df.groupby(['source', 'target']).agg({'count': 'sum'}).reset_index()
    # add index for source-target pair
    st_df['sourceID'] = st_df['source'].apply(lambda x: labels.index(str(x)))
    st_df['targetID'] = st_df['target'].apply(lambda x: labels.index(str(x)))
    # creating the sankey diagram
    data = dict(
        type='sankey', node=dict(
            pad=15, thickness=20, line=dict(color='black', width=0.5), label=labels, color=colors,
        ),
        link=dict(source=st_df['sourceID'], target=st_df['targetID'], value=st_df['count']),
    )
    layout = dict(title=title, font=dict(size=16, family='Arial'))
    # creating figure
    fig = go.Figure(dict(data=[data], layout=layout))
    if fname:
        fig.write_image(f'{fname}.pdf', format='pdf', width=width, height=height, scale=scale)
    return Image(fig.to_image(format='png', width=width, height=height, scale=scale))
The input parameters are:
df is a dataframe with groupings for each set of rows, e.g.:
# g1_l1 means group1, label1
      g1     g2     g3  counts
0  g1_l1  g2_l1  g3_l1      10
1  g1_l3  g2_l2  g3_l1       1
2  g1_l1  g2_l2  g3_l2       1
3  g1_l2  g2_l2  g3_l1      40
4  g1_l2  g2_l3  g3_l2      20
5  g1_l3  g2_l1  g3_l2      10
label_color is a dictionary, where keys are labels and values are colors
categories are the column names of the groupings, in this case ['g1', 'g2', 'g3']
value is the column name of the counts, in this case 'counts'
One example of execution is the following:
import pandas as pd

df = pd.DataFrame([
    ['g1_l1', 'g2_l1', 'g3_l1', 10],
    ['g1_l3', 'g2_l2', 'g3_l1', 1],
    ['g1_l1', 'g2_l2', 'g3_l2', 1],
    ['g1_l2', 'g2_l2', 'g3_l1', 40],
    ['g1_l2', 'g2_l3', 'g3_l2', 20],
    ['g1_l3', 'g2_l1', 'g3_l2', 10],
], columns=['g1', 'g2', 'g3', 'counts'])
label_color = {
    'g1_l1': '#1f77b4', 'g1_l2': '#ff7f0e', 'g1_l3': '#279e68',
    'g2_l1': '#1f77b4', 'g2_l2': '#ff7f0e', 'g2_l3': '#279e68',
    'g3_l1': '#1f77b4', 'g3_l2': '#ff7f0e',
}
pl_sankey(df, label_color, categories=df.columns[:-1], value='counts', title='', fname=None)
However, this code guarantees row matching only between two adjacent columns. Consider, for example, row 1:
      g1     g2     g3  counts
1  g1_l3  g2_l2  g3_l1       1
Such a row should start from the green cluster (g1_l3) in the first column, land in the orange cluster (g2_l2) in the second column, and continue to the blue cluster (g3_l1) in the third column. However, this is not respected in the previous plot, where the inputs into the second column are not sorted in the same order as the matching outputs.
Attached is the annotated plot showing the observation jumping in the second column (the observation is second to last among the inputs of the second column, but last among its outputs):
I would like to follow the path of a row from the first to the last column. Is this possible with a Sankey diagram, and if so, how?
I might have misunderstood something completely here, but I'm hoping to guide you on the right way. So please forgive me if I'm wrong, but it seems you may have misunderstood some of the inner workings of a plotly sankey diagram. And don't worry, you're not alone.
You're stating that:
Such a row should start from the green cluster (g1_l3) in the first column, land in the orange cluster (g2_l2) in the second column, and continue to the blue cluster (g3_l1) in the third column
So if I understand correctly, you're expecting this particular relationship to be illustrated as:
But that's just not the way a plotly sankey diagram is set up to work. Rather, the quantities going from g1_l3 to g2_l2 are grouped together with the other quantities going into g2_l2 and then "sent" along as an aggregated value to g3_l1. The reason why you have this line:
... is because you also have the relationship g2_l2, g3_l1, 1:
If you somehow were to succeed in illustrating the relationships in your dataframe exactly how you describe in a sankey figure, it would no longer be a sankey figure.
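To make the aggregation concrete, here is a minimal pandas sketch that rebuilds the link table the same way the groupby step in pl_sankey does, using the dataframe from the question. The helper name build_links is hypothetical and only for illustration; it is not part of plotly. It shows that rows 1 and 3 collapse into a single g2_l2 to g3_l1 link, so the figure has no information left about which part of that link originally came from g1_l3.

```python
import pandas as pd

df = pd.DataFrame([
    ['g1_l1', 'g2_l1', 'g3_l1', 10],
    ['g1_l3', 'g2_l2', 'g3_l1', 1],
    ['g1_l1', 'g2_l2', 'g3_l2', 1],
    ['g1_l2', 'g2_l2', 'g3_l1', 40],
    ['g1_l2', 'g2_l3', 'g3_l2', 20],
    ['g1_l3', 'g2_l1', 'g3_l2', 10],
], columns=['g1', 'g2', 'g3', 'counts'])


def build_links(df, categories, value):
    """Hypothetical helper: one summed link per (source, target) pair
    of adjacent columns, mirroring the groupby in pl_sankey."""
    pairs = [
        df[[a, b, value]].rename(columns={a: 'source', b: 'target', value: 'count'})
        for a, b in zip(categories[:-1], categories[1:])
    ]
    return (pd.concat(pairs, ignore_index=True)
              .groupby(['source', 'target'], as_index=False)['count']
              .sum())


links = build_links(df, ['g1', 'g2', 'g3'], 'counts')
print(links)

# Everything flowing into g2_l2 (from g1_l1, g1_l2 and g1_l3) ...
into_g2_l2 = links[links['target'] == 'g2_l2']
# ... and everything flowing out of it (to g3_l1 and g3_l2).
out_of_g2_l2 = links[links['source'] == 'g2_l2']

# The single g2_l2 -> g3_l1 link merges row 1 (count 1) and row 3 (count 40).
merged = links[(links['source'] == 'g2_l2') & (links['target'] == 'g3_l1')]
print(merged)
```

The node g2_l2 only knows its total inflow and its total outflow per target; the row identity is gone once the counts are summed, which is why the g1_l3 contribution cannot be traced through the node.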
I'm sorry this is all I could do for you at the moment.