
Is it possible to have some axis labels cover multiple columns in a matplotlib plot?


I'm using matplotlib to create a heatmap with plt.imshow.

The Y-axis represents time, and is OK as is. The X-axis represents features and is the one I'd like to modify.

Some features have a 1:1 mapping between label and column, e.g. the label Length is associated with only one column.

Other features have a 1:n mapping between label and columns, e.g. the label Colors is associated with three columns, each representing one color.

What I would like to achieve is to have all the 1:n labels span the columns they are associated with, like this:

|-------|-------|-------|-------|-------|
|       |       |       |       |       |
|-------|-------|-------|-------|-------|
|       |       |       |       |       |
|-------|-------|-------|-------|-------|
|       |       |       |       |       |
|-------|-------|-------|-------|-------|

|_______|_______________________|_______|
    |               |               |   
 Length           Colors           Size 

Is this possible?

Thanks in advance for any help :-)


Solution

  • The following approach uses lengthened minor ticks as separators between the column groups and zero-length major ticks to place each tick label at the center of its group. Optionally, the minor tick positions can also be used to draw gridlines as an extra visual separation.

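    from matplotlib import pyplot as plt
    from matplotlib.ticker import FixedLocator
    import numpy as np
    
    plt.imshow(np.random.uniform(0, 1, (5, 5)), cmap='inferno')
    # Major ticks carry the labels; hide their tick marks
    plt.tick_params(axis='x', which='major', length=0)
    # Minor ticks mark the group boundaries; make them long enough to act as separators
    plt.tick_params(axis='x', which='minor', length=15)
    # One label at the center of each group: column 0, columns 1-3, column 4
    plt.xticks([0, 2, 4], ['Length', 'Colors', 'Size'])
    # imshow centers column i at x = i, so column edges sit at half-integer positions
    plt.gca().xaxis.set_minor_locator(FixedLocator([-0.5, 0.5, 3.5, 4.5]))
    # Optionally draw white separator lines through the image at the group boundaries:
    # plt.grid(axis='x', which='minor', color='white', lw=2)
    plt.show()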
    


    PS: The positions of the minor and major ticks can be calculated from an array of group widths (the number of columns each label spans):

    widths = np.array([1, 3, 1])  # number of columns per label
    bounds = np.insert(widths, 0, 0).cumsum() - 0.5  # minor ticks: [-0.5, 0.5, 3.5, 4.5]
    ticks_pos = (bounds[:-1] + bounds[1:]) / 2  # major ticks: [0, 2, 4]; same as np.convolve(bounds, [.5, .5], 'valid')
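The width-based calculation and the tick settings from the answer can be combined into a small reusable helper. This is only a sketch, not part of the original answer: the function name `label_column_groups` and its `sep_length` parameter are hypothetical, it uses the object-oriented `Axes` methods instead of the `plt.*` calls, and it assumes the image was drawn with the default `imshow` extent (column i centered at x = i) and that the widths add up to the number of columns.

```python
from matplotlib import pyplot as plt
from matplotlib.ticker import FixedLocator
import numpy as np


def label_column_groups(ax, labels, widths, sep_length=15):
    """Label groups of adjacent imshow columns with one centered tick label each.

    Hypothetical helper (not a matplotlib API): it only wraps the tick
    settings from the answer above.
    labels -- one label per group, e.g. ['Length', 'Colors', 'Size']
    widths -- number of columns in each group, e.g. [1, 3, 1]
    """
    widths = np.asarray(widths)
    if len(labels) != len(widths):
        raise ValueError("need exactly one label per group of columns")
    # Assumes the default imshow extent: column i is centered at x = i,
    # so the group edges fall on half-integers
    bounds = np.insert(widths, 0, 0).cumsum() - 0.5
    # Each label sits midway between its group's left and right edge
    centers = (bounds[:-1] + bounds[1:]) / 2
    # Major ticks: label positions, with zero-length (invisible) tick marks
    ax.set_xticks(centers)
    ax.set_xticklabels(labels)
    ax.tick_params(axis='x', which='major', length=0)
    # Minor ticks: long tick marks at the group edges act as separators
    ax.xaxis.set_minor_locator(FixedLocator(bounds))
    ax.tick_params(axis='x', which='minor', length=sep_length)
    return bounds, centers


fig, ax = plt.subplots()
ax.imshow(np.random.uniform(0, 1, (5, 5)), cmap='inferno')
bounds, centers = label_column_groups(ax, ['Length', 'Colors', 'Size'], [1, 3, 1])
plt.show()
```

Passing the widths instead of hard-coded tick positions keeps the labels and separators aligned when a feature gains or loses columns.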