Search code examples
pythonmatplotlibtreemap

Showing change in a treemap in matplotlib


I am trying to create this:

Treemap in matplotlib

The data for the chart is:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Number of sites per country for the two years being compared.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)

# Change in `sites` relative to the previous row of the same country. The
# first row of each country has no predecessor (NaN) and falls back to its own
# site count. The filled series is assigned back explicitly:
# `df['diff'].fillna(..., inplace=True)` would operate on a temporary object,
# which is deprecated and has no effect once pandas Copy-on-Write is active.
df["diff"] = df.groupby(["countries"])["sites"].diff().fillna(df["sites"])

df

I am aware that there are packages that draw treemaps (squarify and plotly, to name a few), but I have not figured out how to produce the one above, where the values for the two years are drawn on top of each other (or, to be exact, where the difference between them is shown). It would be fantastic to learn how to do it in pure matplotlib, if it is not too complex.

Does anyone have any pointers? I haven't found much information on treemaps on Google.


Solution

  • There are two parts to this task.

    1. Computing a layout for the rectangles.
    2. Drawing the rectangles.

    The first part can get quite involved: people publish scientific papers on the topic. It's not advisable to re-invent the wheel here. However, the second part is quite straightforward and can be done in matplotlib.

    The solution below uses squarify to compute a layout using the larger value for each value pair, and then matplotlib to draw two rectangles on top of each other.

    enter image description here

    import numpy as np
    import matplotlib.pyplot as plt
    import squarify
    
    from matplotlib import colormaps
    from matplotlib.colors import to_rgba
    
    # Default color pairs: consecutive tab20 entries grouped two by two, i.e.
    # (colors[0], colors[1]), (colors[2], colors[3]), ...
    DEFAULT_COLORS = [
        (colormaps["tab20"].colors[ii], colormaps["tab20"].colors[ii + 1])
        for ii in range(0, len(colormaps["tab20"].colors), 2)
    ]
    
    
    def color_to_grayscale(color):
        """
        Reduce a color to a single brightness value.

        The color is converted with matplotlib's `to_rgba`; the red, green and
        blue channels are combined with the weights 0.299, 0.587 and 0.114,
        and the weighted sum is then scaled by the alpha channel.

        Adapted from: https://stackoverflow.com/a/689547/2912349
        """
        red, green, blue, alpha = to_rgba(color)
        luminance = 0.299 * red + 0.587 * green + 0.114 * blue
        return luminance * alpha
    
    
    class PairedTreeMap:
        """
        Treemap in which every tile displays a pair of values.

        Each tile is sized by the larger value of its pair. The smaller value
        is drawn as a second rectangle on top of the tile, anchored at the
        tile's (x, y) origin and scaled so that its area relative to the tile
        equals small / large.

        After construction the instance exposes:

        ax      : the axis that was drawn on
        rects   : the tile layout, one squarify rectangle dict per value pair
        artists : list of (large_patch, small_patch) tuples
        labels  : list of text artists (empty if no labels were given)
        """

        def __init__(self, values, colors=DEFAULT_COLORS, labels=None, ax=None, bbox=(0, 0, 200, 100)):
            """
            Draw a treemap of value pairs.

            values : list[tuple[float, float]]
                A list of value pairs.

            colors : list[tuple[RGBA, RGBA]]
                The corresponding color pairs; the first color belongs to the
                first value of a pair, the second color to the second value.
                Defaults to the (darker, lighter) color pairs of matplotlib's
                tab20 colormap. If there are fewer color pairs than value
                pairs, the color pairs are repeated cyclically.

            labels : list[str]
                The labels, one for each pair.

            ax : matplotlib.axes._axes.Axes
                The matplotlib axis instance to draw on. A new figure and axis
                are created if none is given.

            bbox : tuple[float, float, float, float]
                The (x, y) origin and (width, height) extent of the treemap.

            """

            self.ax = self.initialize_axis(ax)
            self.rects = self.get_layout(values, bbox)

            # zip() stops at the shortest input, so too few color pairs would
            # silently leave the remaining tiles undrawn.
            colors = self._cycle_colors(colors, len(self.rects))
            self.artists = list(self.draw(self.rects, values, colors, self.ax))

            # Always define the attribute so that it can be accessed even
            # when the treemap was created without labels.
            self.labels = []
            if labels is not None:
                self.labels = list(self.add_labels(self.rects, labels, values, colors, self.ax))

        @staticmethod
        def _cycle_colors(colors, count):
            """Return `colors`, repeated cyclically if it has fewer than `count` entries."""
            available = len(colors)
            if available == 0 or available >= count:
                return colors
            return [colors[ii % available] for ii in range(count)]

        @staticmethod
        def _split_pair(value_pair, color_pair):
            """
            Order a value pair and its colors by value.

            Returns (small, large, color_small, color_large). Only the values
            take part in the comparison: the sort is stable, so pairs with
            equal values keep their input order, and colors (which may be of
            different, mutually unorderable types) are never compared.
            """
            ordered = sorted(zip(value_pair, color_pair), key=lambda item: item[0])
            (small, color_small), (large, color_large) = ordered
            return small, large, color_small, color_large

        def get_layout(self, values, bbox):
            """
            Compute one tile per value pair inside `bbox`.

            Tiles are sized by the larger value of each pair. The sizes are
            passed to squarify in descending order, and the resulting
            rectangles (dicts with the keys "x", "y", "dx", "dy") are returned
            in the order of `values`.
            """
            maxima = np.max(values, axis=1)
            order = np.argsort(maxima)[::-1]
            normalized_maxima = squarify.normalize_sizes(maxima[order], *bbox[2:])
            rects = squarify.padded_squarify(normalized_maxima, *bbox)
            # Undo the descending sort so that rects[i] belongs to values[i].
            reorder = np.argsort(order)
            return [rects[ii] for ii in reorder]

        def initialize_axis(self, ax=None):
            """Return `ax` (or a newly created axis) with equal aspect and hidden frame."""
            if ax is None:
                _, ax = plt.subplots()
            ax.set_aspect("equal")
            ax.axis("off")
            return ax

        def _get_artist_pair(self, rect, value_pair, color_pair):
            """
            Create the two rectangles of a tile.

            Returns (large_patch, small_patch). The large patch fills the
            tile; the small patch shares the tile's origin, and both of its
            sides are scaled by sqrt(small / large) so that the ratio of the
            areas equals the ratio of the values.
            """
            x, y, w, h = rect["x"], rect["y"], rect["dx"], rect["dy"]
            small, large, color_small, color_large = self._split_pair(value_pair, color_pair)
            ratio = np.sqrt(small / large)
            return (plt.Rectangle((x, y), w,         h,         color=color_large, zorder=1),
                    plt.Rectangle((x, y), w * ratio, h * ratio, color=color_small, zorder=2))

        def draw(self, rects, values, colors, ax):
            """
            Add the rectangles of all tiles to `ax`.

            This is a generator yielding one (large_patch, small_patch) tuple
            per tile; patches are added as it is consumed, and the view limits
            are rescaled once it is exhausted.
            """
            for rect, value_pair, color_pair in zip(rects, values, colors):
                large_patch, small_patch = self._get_artist_pair(rect, value_pair, color_pair)
                ax.add_patch(large_patch)
                ax.add_patch(small_patch)
                yield (large_patch, small_patch)
            ax.autoscale_view()

        def add_labels(self, rects, labels, values, colors, ax):
            """
            Place each label at the center of its tile.

            This is a generator yielding one text artist per tile. The font
            color is white on dark backgrounds and black on bright ones.
            """
            for rect, label, value_pair, color_pair in zip(rects, labels, values, colors):
                x, y, w, h = rect["x"], rect["y"], rect["dx"], rect["dy"]
                small, large, color_small, color_large = self._split_pair(value_pair, color_pair)
                ratio = small / large
                # Decide which rectangle lies behind the label. The small
                # rectangle starts at the tile origin with sides scaled by
                # sqrt(ratio), so it reaches the tile center once the ratio
                # exceeds 0.25; 0.33 adds some fudge on top of that.
                background = color_large if ratio < 0.33 else color_small
                background_brightness = color_to_grayscale(background)
                fontcolor = "white" if background_brightness < 0.5 else "black"
                yield ax.text(x + w/2, y + h/2, label, va="center", ha="center", color=fontcolor)
    
    
    if __name__ == "__main__":

        # Demo using the data from the question. One row per country:
        # (sites in 2004, sites in 2022), (color for 2004, color for 2022), label.
        demo = [
            ((4, 10), ("red", "coral"), "Denmark"),
            ((13, 15), ("royalblue", "cornflowerblue"), "Sweden"),
            ((5, 8), ("darkslategrey", "gray"), "Norway"),
        ]

        # Split the rows into the three parallel lists the constructor expects.
        values, colors, labels = (list(column) for column in zip(*demo))

        PairedTreeMap(values, colors=colors, labels=labels, bbox=(0, 0, 100, 100))
        plt.show()