
Editing Python Matplotlib Sankey (removing values, aligning texts, ...)


I'm new to the Sankey diagram function in Matplotlib, and I hope someone can help me with a few things I can't seem to figure out.

To learn how to create a Sankey diagram, I followed this example: https://flothesof.github.io/sankey-tutorial-matplotlib.html

This is the Sankey diagram:

[Image: the Sankey diagram]

Now there are a few things I need to change, two of which I don't know how to do:

  • I want to remove the labels that show the values in the middle of the diagram (14460, 9720, 7047, 3059 and 2149).
  • All texts in the middle are bold except the last one (at the bottom, in the yellow part). How can I make that one bold as well?

Hope someone can help.


Solution

  • The .text attribute of each diagram holds its central label as a Text object, and .texts is a list of Text objects, one for each entering or leaving arrow. You can check each string to find out which label it is, and then change properties such as the font weight or the xy-position. In the code below, the bottom label ("Reached final") is attached to the last diagram's outgoing flow, so it is one of the .texts, not the .text.

    Because a Sankey diagram is quite complex, it is easiest to fine-tune positions after the diagram has been created. A simple way to shift a label is to add extra newlines to its string.

    import matplotlib.pyplot as plt
    from matplotlib.sankey import Sankey
    
    fig, ax = plt.subplots(figsize=(8, 12))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("My funnel")
    ax.set_axis_off()
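    # Example data: each stage keeps 60% of the previous one
    number = [round(12345 * 0.6 ** i) for i in range(6)]
    # Leading newlines push each central label further down inside its block
    labels = ["\n\n\n\n\n\n\nTotal joined", "\n\n\n\nFirst", "\n\nSecond", "\nThird", "\nFourth", "\n\nReached final"]
    
    sankey = Sankey(ax=ax, scale=0.0015, offset=0.3)
    for input_number, output_number, label in zip(number[:-1], number[1:], labels):
        this_index = len(sankey.diagrams)
        # Connect every diagram after the first one to the previous diagram
        prior = this_index - 1 if this_index > 0 else None
        # The "quit" arrow of the last diagram gets a longer path
        pathlengths = [0, 0, 2 if this_index != 4 else 10]
        # Only the last diagram labels its outgoing flow ("Reached final")
        exitlabel = labels[-1] if this_index == 4 else None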
        sankey.add(flows=[input_number, -output_number, output_number - input_number],
                   orientations=[0, 0, 1],
                   patchlabel=label,
                   labels=['', exitlabel, 'quit'],
                   prior=prior,
                   connect=(1, 0),
                   pathlengths=pathlengths,
                   trunklength=10.,
                   rotation=-90,
                   facecolor=plt.cm.spring((this_index + 1) / 5))
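    diagrams = sankey.finish()
    for diagram in diagrams:
        # Central label of each block
        diagram.text.set_fontweight('bold')
        diagram.text.set_fontsize(10)
        # Labels of the entering/leaving arrows
        for text in diagram.texts:
            text.set_fontsize(10)
            label = text.get_text()  # can be a normal label, "quit", a number or empty
            if len(label) > 0:
                if label[0].isdigit():
                    # Pure value label: hide it
                    text.set_visible(False)
                elif label.startswith('quit'):
                    # Shift the "quit" labels 2 data units to the right
                    xy = text.get_position()
                    text.set_position((xy[0] + 2, xy[1]))
                else:
                    # Any other flow label (e.g. "Reached final"): make it bold
                    text.set_fontweight('bold')
    # Extend the lower y-limit by 5% to make room at the bottom
    ax.set_ylim(bottom=ax.get_ylim()[0] * 1.05)
    plt.show()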
    

    [Image: example plot]
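To find out which Text object holds which string before hiding or restyling it, it can help to print them first. The sketch below assumes a single made-up three-flow diagram (the flow values and the 'Final' and 'quit' labels are illustrative only). With the default unit='' and format='%G', Sankey appends the formatted value to every flow label (after a newline if the label is non-empty), which is why an empty label turns into a bare number.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window is needed to inspect the texts
import matplotlib.pyplot as plt
from matplotlib.sankey import Sankey

fig, ax = plt.subplots()
# Made-up single stage: 1 unit enters, 0.6 continues, 0.4 quits upwards
sankey = Sankey(ax=ax)
sankey.add(flows=[1, -0.6, -0.4],
           orientations=[0, 0, 1],
           labels=['', 'Final', 'quit'],
           patchlabel='Stage')
diagrams = sankey.finish()

for i, diagram in enumerate(diagrams):
    # .text is the central (patch) label of the diagram
    print(i, 'patch label:', repr(diagram.text.get_text()))
    # .texts holds one Text per flow, in the same order as `flows`
    for text in diagram.texts:
        print('    flow label:', repr(text.get_text()))
```

Whatever this prints is what the get_text() checks in the answer's loop (isdigit(), startswith('quit')) have to match.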
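If the values should never be shown, another option is not to create them at all: the Sankey constructor accepts unit=None, in which case none of the flow quantities are labeled, so there is nothing to hide afterwards. A minimal sketch, assuming a made-up single-stage diagram ('Stage 1', 'Passed', the flow values and the output file name are illustrative); the bold styling is applied the same way as in the loop above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch also runs without a display
import matplotlib.pyplot as plt
from matplotlib.sankey import Sankey

fig, ax = plt.subplots(figsize=(6, 8))
ax.set_axis_off()

# unit=None: Sankey does not append the flow quantities to the labels
sankey = Sankey(ax=ax, unit=None, scale=1 / 1000)
sankey.add(flows=[1000, -600, -400],  # illustrative values: 1000 in, 600 pass, 400 quit
           orientations=[0, 0, 1],
           labels=['', 'Passed', 'quit'],
           patchlabel='Stage 1',
           rotation=-90)
diagrams = sankey.finish()

for diagram in diagrams:
    # Central label of the block
    diagram.text.set_fontweight('bold')
    # Only text labels remain on the flows, so bold every non-empty one
    for text in diagram.texts:
        if text.get_text():
            text.set_fontweight('bold')

# Hypothetical output file; with an interactive backend use plt.show() instead
fig.savefig("sankey_no_values.png")
```

With this approach the label[0].isdigit() check is no longer needed; the trade-off is that the values disappear from every arrow, including any where you might have wanted to keep them.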