Tags: python, matplotlib, data-visualization, squarify

How to add a subplot to each rectangle in a Tree Map?


I created this treemap using Matplotlib and Squarify:

[image: the treemap]

Now I would like to add a line plot on each rectangle in the tree map. Is that possible?


Solution

  • Squarify's plot is a convenience function that draws a treemap directly from values and labels. The same process can also be carried out step by step. One of those steps computes the positions of the rectangles; here we assume a coordinate system running from (0, 0) at the lower left to (1, 1) at the upper right, which is exactly Matplotlib's figure-fraction coordinate system.

    With these rectangles we can manually place an axes on each one and draw into it. The question doesn't say whether ticks are needed. If they are, they can be drawn inside each axes. Alternatively, the axis lines could be moved to the center of each subplot, or the axes could show ticks without labels.
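The sizing step is plain arithmetic: every value is scaled so that together they cover the area (width * height) of the rectangle being split, while their ratios stay the same. A minimal sketch of that rescaling, assuming it matches what squarify.normalize_sizes does; normalize_sizes_sketch is a hypothetical helper, not part of squarify:

```python
def normalize_sizes_sketch(sizes, dx, dy):
    """Hypothetical stand-in for squarify.normalize_sizes: rescale sizes so
    they sum to the area dx * dy while keeping their ratios."""
    total = sum(sizes)
    if total <= 0:
        raise ValueError("sizes must sum to a positive number")
    area = dx * dy
    return [size * area / total for size in sizes]


values = sorted([500, 433, 331, 254, 119], reverse=True)
scaled = normalize_sizes_sketch(values, 1, 1)  # unit square, as in the answer
print(round(sum(scaled), 9))  # 1.0
```

Passing the width and height of a smaller box instead of 1, 1 only changes the target area; the relative sizes are unchanged.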

    Here is some code demonstrating this:

    import numpy as np
    import matplotlib.pyplot as plt
    import squarify
    
    values = [500, 433, 331, 254, 119]
    values.sort(reverse=True)  # values must be sorted descending (and positive)
    # the sum of the values must equal the total area to be laid out; i.e., sum(values) == width * height
    values = squarify.normalize_sizes(values, 1, 1)
    rects = squarify.squarify(values, 0, 0, 1, 1)
    
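    fig = plt.figure(figsize=(7, 5))
    
    # one axes per rectangle; squarify's x, y, dx, dy are already figure fractions
    axes = [fig.add_axes([rect['x'], rect['y'], rect['dx'], rect['dy']]) for rect in rects]
    
    for ax, color in zip(axes, plt.cm.Pastel1.colors):
        # random walk as stand-in data for each line plot
        x = np.linspace(0, 10, 100)
        y = np.random.normal(0.01, 0.1, 100).cumsum()
        ax.plot(x, y)
        # point ticks inward and pull labels inside (negative pad) so they don't overlap neighbours
        ax.tick_params(axis="y", direction="in", pad=-15)
        ax.tick_params(axis="x", direction="in", pad=-15)
        # left-align the y tick labels so they extend inward from the left edge
        plt.setp(ax.get_yticklabels(), ha="left")
        ax.set_facecolor(color)
    plt.show()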
    

    [image: resulting plot]
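For the "ticks without labels" option, here is a variant sketch of the axes-creation loop. Assumptions: the rectangles are hard-coded in the same x/y/dx/dy dict format that squarify.squarify returns (so the sketch runs without squarify), axes_for_rects is a hypothetical helper, and the Agg backend is used so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical rectangles in squarify's dict format (keys x, y, dx, dy),
# given in figure fractions from (0, 0) to (1, 1).
rects = [
    {"x": 0.0, "y": 0.0, "dx": 0.6, "dy": 1.0},
    {"x": 0.6, "y": 0.0, "dx": 0.4, "dy": 0.5},
    {"x": 0.6, "y": 0.5, "dx": 0.4, "dy": 0.5},
]


def axes_for_rects(fig, rects):
    """Create one axes per rectangle, keeping ticks but hiding their labels."""
    axes = []
    for rect in rects:
        ax = fig.add_axes([rect["x"], rect["y"], rect["dx"], rect["dy"]])
        # ticks point inward so they stay inside the rectangle; labels are switched off
        ax.tick_params(direction="in", labelbottom=False, labelleft=False)
        axes.append(ax)
    return axes


fig = plt.figure(figsize=(7, 5))
axes = axes_for_rects(fig, rects)
rng = np.random.default_rng(0)
for ax, color in zip(axes, plt.cm.Pastel1.colors):
    x = np.linspace(0, 10, 100)
    ax.plot(x, rng.normal(0.01, 0.1, 100).cumsum())  # random-walk demo data
    ax.set_facecolor(color)
fig.canvas.draw()  # render once so the ticks are created
```

Hiding only the labels keeps a visual sense of scale while avoiding label collisions between neighbouring rectangles.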

    Here is another example resembling the image in the question, with a main plot and a colorbar. With its default settings, mplcursors gets confused by all these axes, but hover annotations can also be added manually.
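The manual hover annotations in the following code come down to two pieces of logic: finding which small axes contains the mouse position, and formatting the label text. A stdlib-only sketch of both, so the logic can be checked without a GUI. The helpers hit_test and annotation_text and the pixel rectangles are hypothetical; hit_test mimics the inclusive edge check of Matplotlib's Bbox.contains:

```python
def hit_test(rects, x, y):
    """Index of the first rectangle containing (x, y), or None.
    rects use squarify's x/y/dx/dy keys; edges count as inside."""
    for i, r in enumerate(rects):
        if r["x"] <= x <= r["x"] + r["dx"] and r["y"] <= y <= r["y"] + r["dy"]:
            return i
    return None


def annotation_text(label, value, total):
    """Same text format as the hover annotations: label, value, percentage."""
    return f"'{label}': {value}\n{value * 100.0 / total:.1f} %"


# Hypothetical layout in pixels for a 600 x 400 figure
rects = [
    {"x": 0, "y": 0, "dx": 360, "dy": 400},
    {"x": 360, "y": 0, "dx": 240, "dy": 200},
    {"x": 360, "y": 200, "dx": 240, "dy": 200},
]
values = [4000, 1500, 1500]
labels = ["a", "b", "c"]

i = hit_test(rects, 500, 300)  # mouse over the upper-right rectangle
print(i)  # 2
print(annotation_text(labels[i], values[i], sum(values)))
```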

    import numpy as np
    import matplotlib.pyplot as plt
    import squarify
    
    values = [4000, 1500, 1500, 1200, 1000, 500]
    
    fig, mainax = plt.subplots(figsize=(6, 4))
    
    mainax.set_xlim(0, 1000)
    mainax.set_ylim(0, 1000)
    mainax.grid(False)
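    cmap = plt.get_cmap('Greens')  # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and later removed
    norm = plt.Normalize(vmin=0, vmax=max(values))
    # pass ax= so the colorbar knows which axes to shrink; recent Matplotlib versions no longer guess it
    fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm), ax=mainax)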
    pos = mainax.get_position()
    
    values.sort(reverse=True)
    normalized_values = squarify.normalize_sizes(values, pos.width, pos.height)
    rects = squarify.squarify(normalized_values, pos.x0, pos.y0, pos.width, pos.height)
    
    axes = [fig.add_axes([rect['x'], rect['y'], rect['dx'], rect['dy'], ]) for rect in rects]
    
    for ax, val in zip(axes, values):
        x = np.linspace(0, 10, 100)
        y = np.random.normal(0.01, 0.1, 100).cumsum()
        ax.plot(x, y)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_facecolor(cmap(norm(val)))
    mainax.set_facecolor('none')  # transparent, so mainax doesn't hide the small axes beneath it
    mainax.set_zorder(20)  # keep mainax on top, because the annotations are drawn on it
    
    labels = ['a', 'b', 'c', 'd', 'e', 'f']
    sum_val = sum(values)
    annotations = [mainax.annotate(f"'{lbl}': {val}\n{val * 100.0 / sum_val:.1f} %",
                                   xy=(0, 0), xycoords='figure pixels',
                                   xytext=(0, 0), textcoords='offset points',
                                   bbox=dict(boxstyle='round', fc='lemonchiffon'),
                                   ha='center', va='bottom')
                   for ax, val, lbl in zip(axes, values, labels)]
    for annot in annotations:
        annot.set_visible(False)
    
    def hover(event):
        for ax, annot in zip(axes, annotations):
            if ax.bbox.contains(event.x, event.y):
                annot.xy = (event.x, event.y)
                annot.set_visible(True)
            else:
                annot.set_visible(False)
        fig.canvas.draw_idle()
    
    fig.canvas.mpl_connect("motion_notify_event", hover)
    
    plt.show()
    

    [image: plot with colorbar]
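The hover code compares event.x and event.y against ax.bbox and places the annotations with xycoords='figure pixels'. Both work because Matplotlib converts the figure-fraction rectangle given to fig.add_axes into display pixels by scaling it with the figure's size in pixels. A small check of that conversion, assuming a figure created with dpi=100 and the headless Agg backend (interactive backends that scale the canvas for high-resolution screens may report different pixel values):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window needed
import matplotlib.pyplot as plt

# 6 x 4 inches at 100 dpi gives a 600 x 400 pixel figure
fig = plt.figure(figsize=(6, 4), dpi=100)
# an axes placed in figure fractions, like the rectangles returned by squarify
ax = fig.add_axes([0.25, 0.5, 0.5, 0.25])

# ax.bbox holds the same rectangle in display pixels:
# x0 = 0.25 * 600, y0 = 0.5 * 400, width = 0.5 * 600, height = 0.25 * 400
x0, y0, width, height = (float(v) for v in ax.bbox.bounds)
print(x0, y0, width, height)  # 150.0 200.0 300.0 100.0

# the same containment test the hover handler uses
print(bool(ax.bbox.contains(300, 250)), bool(ax.bbox.contains(10, 10)))  # True False
```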