Tags: python-2.7, matplotlib, heatmap, correlation, axis-labels

matplotlib correlation matrix heatmap with grouped colors as labels


I have a correlation matrix that I am trying to visualize with matplotlib. I can create a heatmap-style figure just fine, but I am running into problems with the labels. I'm not even sure this is possible, but here is what I'm trying to do and can't get to work:

My correlation matrix is 150 x 150. On either the x or the y axis (or both; it doesn't matter), I would like to group the labels and then mark each group simply with a color, or with a white label on a colored background.

To clarify, let's say I'd like items 1-15 to be "Group 1", shown either as a plain blue bar or as "Group 1" text on a blue bar. Then items 16-20 would be "Group 2" on a red bar (or simply a red bar), and so on through all of the items in the matrix.
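
A grouping like this can be encoded as one row of inclusive [start, end] indices per group, which is the form the solution below uses. A minimal sketch, assuming you know how many rows/columns each group contains; the helper name `group_ranges` and the sizes are hypothetical, chosen here to reproduce the five 0-based ranges used in the solution:

```python
import numpy as np


def group_ranges(sizes):
    """Hypothetical helper: turn group sizes into inclusive [start, end] pairs (0-based)."""
    ends = np.cumsum(sizes) - 1                     # last index of each group
    starts = np.concatenate(([0], ends[:-1] + 1))   # first index of each group
    return np.column_stack((starts, ends))


# Hypothetical sizes; they must add up to the matrix size (150 here)
labels = group_ranges([16, 21, 46, 29, 38])
print(labels.tolist())
# [[0, 15], [16, 36], [37, 82], [83, 111], [112, 149]]
```

Working from sizes rather than typing the boundaries by hand avoids gaps or overlaps between neighbouring groups.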

I have been failing both at grouping the axis labels and at getting any color on them. Any help would be greatly appreciated. My code is below; it's quite basic and I don't know whether it will help.

import matplotlib.pyplot as plt
import numpy as np

csv_path = 'corr.csv'  # placeholder: path to the 150 x 150 correlation CSV

# CORRELATION MATRIX TEST #
corr = np.genfromtxt(csv_path, delimiter=',')
fig = plt.figure()
ax1 = fig.add_subplot(111)
cmap = plt.get_cmap('jet', 30)  # 'jet' resampled to 30 discrete colors
# pin the color scale to [-1, 1] so every colorbar tick below is shown
cax = ax1.imshow(corr, cmap=cmap, vmin=-1, vmax=1)
ax1.grid(True)
plt.title('THIS IS MY TITLE')
fig.colorbar(cax, ticks=[-1, -0.8, -0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
plt.show()

Solution

  • You can create auxiliary axes next to the plot and draw colored bar plots in them. Turning those axes off (spines, ticks and tick labels) makes the bars look like label boxes.

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    
    # CORRELATION MATRIX TEST #
    corr = 2*np.random.rand(150,150)-1  # random stand-in for the real matrix
    # group boundaries [start, end], both inclusive, 0-based
    labels = np.array([[0,15],[16,36],[37,82],[83,111],[112,149]])
    colors = ["crimson", "limegreen","gold","orchid","turquoise"]
    
    fig, ax = plt.subplots()
    
    im = ax.imshow(corr, cmap="Blues", vmin=-1, vmax=1)  # pin the color scale to [-1, 1]
    
    ax.set_title('THIS IS MY TITLE')
    fig.colorbar(im, ticks=[-1,-0.8,-0.6,-0.4,-0.2,0.0,0.2,0.4,0.6,0.8,1.0])
    
    # create axes next to plot
    divider = make_axes_locatable(ax)
    axb = divider.append_axes("bottom", "10%", pad=0.06, sharex=ax)
    axl = divider.append_axes("left", "10%", pad=0.06, sharey=ax)
    axb.invert_yaxis()
    axl.invert_xaxis()
    axb.axis("off")
    axl.axis("off")
    
    
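    # plot colored bar plots to the axes
    # imshow centers cell i on i, so cell i spans i-0.5 .. i+0.5: shift each
    # box left by half a cell and make it (end - start + 1) cells wide so the
    # boxes line up with the matrix cells and touch without gaps
    barkw = dict(color=colors, linewidth=0.72, ec="k", clip_on=False, align='edge')
    starts = labels[:,0] - 0.5
    sizes = np.diff(labels, axis=1).flatten() + 1
    axb.bar(starts, np.ones(len(labels)), width=sizes, **barkw)
    axl.barh(starts, np.ones(len(labels)), height=sizes, **barkw)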
    
    # set margins to zero again
    ax.margins(0)
    # hide the main axes' own ticks; the colored boxes act as the labels
    ax.tick_params(axis="both", bottom=False, left=False, labelbottom=False, labelleft=False)
    # Label the boxes
    textkw = dict(ha="center", va="center", fontsize="small")
    for k,l in labels:
        axb.text((k+l)/2.,0.5, "{}-{}".format(k,l), **textkw)
        axl.text(0.5,(k+l)/2., "{}-{}".format(k,l), rotation=-90,**textkw)
    
    plt.show()
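
If the "white label on a colored background" variant from the question is enough, a lighter alternative (a different technique from the auxiliary-axes approach above) is to put one tick per group at the group's center and style the tick labels themselves. A minimal sketch, assuming the same five 0-based ranges; the group names and the helper `label_groups` are hypothetical:

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical groups: (start, end, name, color), indices 0-based and inclusive
groups = [(0, 15, "Group 1", "crimson"),
          (16, 36, "Group 2", "limegreen"),
          (37, 82, "Group 3", "gold"),
          (83, 111, "Group 4", "orchid"),
          (112, 149, "Group 5", "turquoise")]


def label_groups(ax, groups):
    """Hypothetical helper: one tick per group at the group's center, with the
    group name drawn as white text on a box filled with the group's color."""
    centers = [(start + end) / 2.0 for start, end, _, _ in groups]
    names = [name for _, _, name, _ in groups]
    ax.set_xticks(centers)
    ax.set_yticks(centers)
    xlabels = ax.set_xticklabels(names, rotation=90)
    ylabels = ax.set_yticklabels(names)
    for tick_labels in (xlabels, ylabels):
        for text, (_, _, _, color) in zip(tick_labels, groups):
            text.set_color("white")
            text.set_bbox(dict(facecolor=color, edgecolor="none", pad=2))
    ax.tick_params(length=0)  # hide the tick marks, keep the colored labels


corr = 2 * np.random.rand(150, 150) - 1  # random stand-in for the real matrix
fig, ax = plt.subplots()
im = ax.imshow(corr, cmap="Blues", vmin=-1, vmax=1)
fig.colorbar(im)
label_groups(ax, groups)
```

Call `plt.show()` (or `fig.savefig(...)`) afterwards to see the result. The trade-off is that each colored box is only as large as its text, not as wide as its group; when the colors must span each group exactly, the auxiliary-axes answer is the better fit.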