Tags: python, python-3.x, pandas, matplotlib, data-visualization

Double header in Matplotlib Table


I need to plot a table in matplotlib. The problem is that some columns have one-level headers, while others have two-level headers.

Here's what I need:

[Image: the table needed]

Here's a simple example with one-level headers:

import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame()
df['Animal'] = ['Cow', 'Bear']
df['Weight'] = [250, 450]
df['Favorite'] = ['Grass', 'Honey']
df['Least Favorite'] = ['Meat', 'Leaves']
df

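If the DataFrame itself may be restructured, pandas can represent a two-level header directly with a MultiIndex on the columns. This is an assumption, not part of the setup above: single-level columns get an empty string as their second level, and each level then supplies one header row.

```python
import pandas as pd

# Assumption: encode the desired header as a MultiIndex;
# single-level columns use '' for the second level.
columns = pd.MultiIndex.from_tuples([
    ('Animal', ''),
    ('Weight', ''),
    ('Food', 'Favorite'),
    ('Food', 'Least Favorite'),
])
df = pd.DataFrame(
    [['Cow', 250, 'Grass', 'Meat'],
     ['Bear', 450, 'Honey', 'Leaves']],
    columns=columns,
)

# Each level of the MultiIndex corresponds to one header row
top = list(df.columns.get_level_values(0))
bottom = list(df.columns.get_level_values(1))
print(top)     # ['Animal', 'Weight', 'Food', 'Food']
print(bottom)  # ['', '', 'Favorite', 'Least Favorite']
```

Selecting the group label, as in `df['Food']`, returns just the `Favorite` and `Least Favorite` columns.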

fig = plt.figure(figsize=(9, 2))
ax = plt.subplot(111)
ax.axis('off')
table = ax.table(cellText=df.values, colColours=['grey']*df.shape[1],
                 bbox=[0, 0, 1, 1], colLabels=df.columns)
plt.savefig('Table.jpg')

The last chunk of code produces the following picture:

[Image: the resulting table, with a single header row]

What changes do I need to make to get the table I need?
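For context, `ax.table` takes a single list of `colLabels`, so the API offers only one header row. The labels occupy row 0 of the resulting table, the data rows are shifted down by one, and every cell can be addressed as `table[row, col]`. A minimal check, assuming the non-interactive Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({
    'Animal': ['Cow', 'Bear'],
    'Weight': [250, 450],
    'Favorite': ['Grass', 'Honey'],
    'Least Favorite': ['Meat', 'Leaves'],
})

fig, ax = plt.subplots(figsize=(9, 2))
ax.axis('off')
table = ax.table(cellText=df.values, colLabels=df.columns, bbox=[0, 0, 1, 1])

# colLabels fills row 0; data rows start at row 1
print(table[0, 0].get_text().get_text())  # Animal
print(table[1, 0].get_text().get_text())  # Cow
# 1 header row + 2 data rows, 4 columns each
print(len(table.get_celld()))             # 12
```

Because only one label row is available, a second header row has to be built some other way, for example as extra rows of `cellText`.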


Solution

  • Cell merge solution

    You can merge the cells produced by ax.table, à la the cell-merge function in an Excel spreadsheet. This allows for a completely automated solution in which you don't need to fiddle with any coordinates (save for the indices of the cells you want to merge):

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    
    # mergecells is defined further down; define it before running this block
    
    df = pd.DataFrame()
    df['Animal'] = ['Cow', 'Bear']
    df['Weight'] = [250, 450]
    df['Favorite'] = ['Grass', 'Honey']
    df['Least Favorite'] = ['Meat', 'Leaves']
    
    fig = plt.figure(figsize=(9, 2))
    ax = fig.gca()
    ax.axis('off')
    r, c = df.shape
    
    # ensure consistent background color: a 2-row background table whose
    # gray top row sits behind the two header rows of the real table
    ax.table(cellColours=[['lightgray']] + [['none']], bbox=[0, 0, 1, 1])
    
    # plot the real table: two header rows stacked on top of the data
    table = ax.table(cellText=np.vstack([['', '', 'Food', ''], df.columns, df.values]),
                     cellColours=[['none']*c]*(2 + r), bbox=[0, 0, 1, 1])
    
    # need to draw here so the text positions are calculated
    fig.canvas.draw()
    
    # do the 3 cell merges needed (indices are (row, col), row 0 at the top)
    mergecells(table, (1,0), (0,0))
    mergecells(table, (1,1), (0,1))
    mergecells(table, (0,2), (0,3))
    
    plt.savefig('Table.jpg')
    

    Output:

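    [Image: the resulting table, with 'Animal' and 'Weight' spanning both header rows and 'Food' spanning the 'Favorite' and 'Least Favorite' columns]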
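The header row `['', '', 'Food', '']` and the three `mergecells` calls are written by hand for this particular DataFrame. A hypothetical helper (`header_rows_and_merges` is an illustrative name, not a matplotlib function) could derive both from two-level labels such as those of a pandas MultiIndex, which iterates as `(top, bottom)` tuples. It assumes an empty bottom label marks a single-level column and, because `mergecells` joins only two adjacent cells, it rejects groups spanning more than two columns:

```python
def header_rows_and_merges(columns):
    """Hypothetical helper: build two header rows and merge pairs.

    columns is a sequence of (top, bottom) label tuples; an empty bottom
    label marks a single-level column. Returns (top_row, bottom_row, merges)
    where merges holds ((row, col), (row, col)) pairs for mergecells.
    """
    columns = list(columns)
    top_row, bottom_row, merges = [], [], []
    c = 0
    while c < len(columns):
        top, bottom = columns[c]
        if bottom == '':
            # single-level column: label in row 1, merged up into row 0
            top_row.append('')
            bottom_row.append(top)
            merges.append(((1, c), (0, c)))
            c += 1
            continue
        # count consecutive two-level columns sharing this top label
        span = 1
        while (c + span < len(columns)
               and columns[c + span][0] == top
               and columns[c + span][1] != ''):
            span += 1
        if span > 2:
            raise ValueError(f"group {top!r} spans {span} columns; "
                             "mergecells only joins two adjacent cells")
        # group label goes in row 0 above the group's first column
        top_row.append(top)
        bottom_row.append(bottom)
        if span == 2:
            top_row.append('')
            bottom_row.append(columns[c + 1][1])
            merges.append(((0, c), (0, c + 1)))
        c += span
    return top_row, bottom_row, merges


top, bottom, merges = header_rows_and_merges([
    ('Animal', ''), ('Weight', ''),
    ('Food', 'Favorite'), ('Food', 'Least Favorite'),
])
print(top)     # ['', '', 'Food', '']
print(bottom)  # ['Animal', 'Weight', 'Favorite', 'Least Favorite']
print(merges)  # [((1, 0), (0, 0)), ((1, 1), (0, 1)), ((0, 2), (0, 3))]
```

The outputs would slot into the answer's code as `np.vstack([top, bottom, df.values])` for `cellText`, followed by `for ix0, ix1 in merges: mergecells(table, ix0, ix1)` after the draw call.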

    Here's the code for the mergecells function used above:

    import matplotlib as mpl
    import numpy as np
    
    def mergecells(table, ix0, ix1):
        """Merge two adjacent cells of a matplotlib table.
    
        ix0 and ix1 are (row, col) indices. The text of ix0 is kept and moved
        toward the middle of the merged area; the text of ix1 is hidden.
        """
        ix0,ix1 = np.asarray(ix0), np.asarray(ix1)
        d = ix1 - ix0
        if not (0 in d and 1 in np.abs(d)):
            raise ValueError("ix0 and ix1 should be the indices of adjacent cells. ix0: %s, ix1: %s" % (ix0, ix1))
    
        # choose the visible edges of each cell (B=bottom, R=right, T=top, L=left)
        if d[0]==-1:
            # ix1 is above ix0
            edges = ('BRL', 'TRL')
        elif d[0]==1:
            # ix1 is below ix0
            edges = ('TRL', 'BRL')
        elif d[1]==-1:
            # ix1 is left of ix0
            edges = ('BTR', 'BTL')
        else:
            # ix1 is right of ix0
            edges = ('BTL', 'BTR')
    
        # hide the merged edges
        for ix,e in zip((ix0, ix1), edges):
            table[ix[0], ix[1]].visible_edges = e
    
        # text positions are in display (pixel) coordinates, set by the last draw
        txts = [table[ix[0], ix[1]].get_text() for ix in (ix0, ix1)]
        tpos = [np.array(t.get_position()) for t in txts]
    
        # center the text of the 0th cell between the two merged cells
        trans = (tpos[1] - tpos[0])/2
        if trans[0] > 0 and txts[0].get_ha() == 'right':
            # reduce the transform distance in order to center the text
            # (a heuristic: right-aligned text ends up only roughly centered)
            trans[0] /= 2
        elif trans[0] < 0 and txts[0].get_ha() == 'right':
            # increase the transform distance...
            trans[0] *= 2
    
        txts[0].set_transform(mpl.transforms.Affine2D().translate(*trans))
    
        # hide the text in the 1st cell
        txts[1].set_visible(False)
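The edge hiding in `mergecells` rests on the `visible_edges` property of matplotlib's table `Cell`, which accepts a string built from `B`, `R`, `T` and `L` (bottom, right, top, left) or one of the aliases `'open'`, `'closed'`, `'horizontal'` and `'vertical'`. A small check of the strings assigned for the `(1, 0)` to `(0, 0)` merge, assuming the non-interactive Agg backend:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend for this sketch
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 1))
ax.axis('off')
table = ax.table(cellText=[['', ''], ['Animal', 'Weight']], bbox=[0, 0, 1, 1])

upper, lower = table[0, 0], table[1, 0]
# Hide the shared border: the lower cell drops its top (T) edge and the
# upper cell drops its bottom (B) edge, as mergecells does when d[0] == -1
lower.visible_edges = 'BRL'
upper.visible_edges = 'TRL'

# The setter expands aliases into edge strings
other = table[0, 1]
other.visible_edges = 'horizontal'
print(other.visible_edges)  # BT
other.visible_edges = 'closed'
print(other.visible_edges)  # BRTL

# Render; the border between cells (0, 0) and (1, 0) is no longer drawn
fig.canvas.draw()
```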