
How to select a specific number of colors to show in a color bar from a big list? - Matplotlib


I plotted some data that has 70 classes, and when I build the color bar it is very difficult to tell the individual labels apart, as shown below:

[Screenshot omitted: color bar crowded with all 70 labels]

The code that I'm using is:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable

formation_colors = ...  # list of 70 colors
formation_labels = ...  # list of 70 labels
data = ...  # the part of the full dataset that contains only 13 of the labels

data = data.sort_values(by='DEPTH_MD')
ztop = data.DEPTH_MD.min(); zbot = data.DEPTH_MD.max()

# one color per formation, indexed 0..69
cmap_formations = colors.ListedColormap(formation_colors[0:len(formation_colors)], 'indexed')
# turn the formation column into a 2-D image that is 100 pixels wide
cluster_f = np.repeat(np.expand_dims(data['Formations'].values, 1), 100, 1)

fig = plt.figure(figsize=(2, 10))
ax = fig.add_subplot()
im_f = ax.imshow(cluster_f, interpolation='none', aspect='auto', cmap=cmap_formations, vmin=0, vmax=69)
ax.set_xlabel('FORMATION')
ax.set_xticklabels([''])

divider_f = make_axes_locatable(ax)
cax_f = divider_f.append_axes("right", size="20%", pad=0.05)
cbar_f = plt.colorbar(im_f, cax=cax_f)

cbar_f.set_ticks(range(0, len(formation_labels))); cbar_f.set_ticklabels(formation_labels)

So far, if I change just these two lines:

   1. cmap_formations = colors.ListedColormap(formation_colors[0:len(formation_colors)], 'indexed') 
   2. cbar_f.set_ticks(range(0,len(formation_labels))); cbar_f.set_ticklabels(formation_labels)

to:

cmap_formations = colors.ListedColormap(formation_colors[0:len(data['FORMATION'].unique())], 'indexed') 

cbar_f.set_ticks(range(0,len(data['FORMATION'].unique()))); cbar_f.set_ticklabels(data['FORMATION'].unique())

I get the corresponding colors in the color bar; however, the plot is no longer correct, and the labels are no longer aligned with their color squares:

[Screenshot omitted: result after the change above]

Thanks in advance if you have any idea how to do this.


Solution

  • Although the question doesn't say so explicitly, I assume data['FORMATION'] contains indices from 0 to 69 into the lists formation_colors and formation_labels.

    The main problem is that the values in data['FORMATION'] need to be renumbered so that they become indices 0 to 12 into the reduced list of unique colors. np.unique(..., return_inverse=True) returns both the sorted array of unique values and, for each original value, its position in that array, which is exactly this renumbering.
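A minimal illustration of what np.unique(..., return_inverse=True) returns, on a small made-up array of formation indices (the values are hypothetical, not taken from the question's data):

```python
import numpy as np

# Hypothetical formation indices (from the 0..69 range), as they might appear down a well
formations = np.array([42, 42, 7, 7, 7, 63, 42, 7])

unique_values, new_values = np.unique(formations, return_inverse=True)

print(unique_values)  # [ 7 42 63]
print(new_values)     # [1 1 0 0 0 2 1 0]

# The renumbering loses nothing: indexing the unique values recovers the originals
assert (unique_values[new_values] == formations).all()
```

The new values run from 0 to len(unique_values) - 1, so they can be used directly as indices into a colormap that holds only the colors of the formations that occur.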

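    To reindex the lists of colors and labels with the array of unique values, it helps to convert them to NumPy arrays: a NumPy array accepts an integer array as an index, whereas a Python list does not.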
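A small sketch of why the conversion matters, using hypothetical 5-entry stand-ins for the 70-entry lists: indexing a plain list with an integer array raises a TypeError, while the NumPy array returns the selected entries in order.

```python
import numpy as np

# Hypothetical stand-ins for the question's 70-entry lists
formation_colors = ['red', 'green', 'blue', 'orange', 'purple']
formation_labels = ['lbl_red', 'lbl_green', 'lbl_blue', 'lbl_orange', 'lbl_purple']

unique_values = np.array([1, 3, 4])  # formations that actually occur in the data

# A plain list cannot be indexed with an array of positions
try:
    formation_colors[unique_values]
except TypeError as exc:
    print('list indexing failed:', exc)

# After conversion, fancy indexing selects the matching colors and labels
colors_arr = np.asarray(formation_colors)
labels_arr = np.asarray(formation_labels)
subset_colors = colors_arr[unique_values]
subset_labels = labels_arr[unique_values]
print(subset_colors.tolist())  # ['green', 'orange', 'purple']
```

If you prefer to keep plain lists, a list comprehension such as [formation_colors[i] for i in unique_values] gives the same selection.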

    To make the code easier to debug, the following test uses a simple relation between the colors and the labels: each label is the color name prefixed with 'lbl_'.

    from matplotlib import pyplot as plt
    from matplotlib import colors
    from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
    import numpy as np
    import pandas as pd
    
    formation_colors = np.random.choice(list(colors.CSS4_COLORS), 70, replace=False)  # 70 random color names
    formation_labels = ['lbl_' + c for c in formation_colors]  # 70 labels
    formation_colors = np.asarray(formation_colors)
    formation_labels = np.asarray(formation_labels)
    
    f = np.random.randint(0, 70, 13)
    d = np.sort(np.random.randint(0, 5300, 13))
    data = pd.DataFrame({'FORMATION': np.repeat(f, np.diff(np.append(0, d))),
                         'DEPTH_MD': np.arange(d[-1])})
    data = data.sort_values(by='DEPTH_MD')
    
    ztop = data['DEPTH_MD'].min()
    zbot = data['DEPTH_MD'].max()
    
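    # unique_values: the sorted formation indices that actually occur in data
    # formation_new_values: each row's position in unique_values (0 .. len(unique_values)-1)
    unique_values, formation_new_values = np.unique(data['FORMATION'], return_inverse=True)
    # keep only the colors of the formations that occur
    cmap_formations = colors.ListedColormap(formation_colors[unique_values], 'indexed')
    cluster_f = formation_new_values.reshape(-1, 1)  # single-column image, one row per depth sample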
    
    fig = plt.figure(figsize=(3, 10))
    ax = fig.add_subplot()
    im_f = ax.imshow(cluster_f, extent=[0, 1, zbot, ztop],
                     interpolation='none', aspect='auto', cmap=cmap_formations, vmin=0, vmax=len(unique_values)-1)
    ax.set_xlabel('FORMATION')
    ax.set_xticks([])
    
    divider_f = make_axes_locatable(ax)
    cax_f = divider_f.append_axes("right", size="20%", pad=0.05)
    cbar_f = plt.colorbar(im_f, cax=cax_f)
    
    # one tick at the center of each color band: the bar spans 0..n-1 and is split into n equal bands
    cbar_f.set_ticks(np.linspace(0, len(unique_values)-1, 2*len(unique_values)+1)[1::2])
    cbar_f.set_ticklabels(formation_labels[unique_values])
    
    plt.subplots_adjust(left=0.2, right=0.5)
    plt.show()
    

    Here is a comparison plot:

    [Comparison plot omitted]
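The tick positions in the answer come from the band geometry: with vmin=0 and vmax=n-1 the color bar spans n-1 data units, and a ListedColormap with n colors splits that range into n equal bands of width (n-1)/n, whose centers are (i + 0.5)(n - 1)/n. The linspace(...)[1::2] expression picks exactly those centers. The sketch below checks this numerically on a hypothetical 4-color palette, and also shows a commonly used alternative that is not part of the answer above: setting vmin=-0.5 and vmax=n-0.5, so every band is one unit wide and centered on an integer, which lets the ticks simply be range(n).

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib import colors

# Hypothetical reduced palette and labels (stand-ins for formation_colors[unique_values] etc.)
palette = ['tab:red', 'tab:green', 'tab:blue', 'tab:orange']
labels = ['lbl_a', 'lbl_b', 'lbl_c', 'lbl_d']
n = len(palette)
cmap = colors.ListedColormap(palette)

# Answer's scheme: vmin=0, vmax=n-1, ticks at the band centers (i + 0.5) * (n - 1) / n
ticks_answer = np.linspace(0, n - 1, 2 * n + 1)[1::2]
assert np.allclose(ticks_answer, (np.arange(n) + 0.5) * (n - 1) / n)

# Both normalizations send the renumbered value i to palette[i]
norm_answer = colors.Normalize(vmin=0, vmax=n - 1)
norm_shifted = colors.Normalize(vmin=-0.5, vmax=n - 0.5)
for i in range(n):
    expected = colors.to_rgba(palette[i])
    assert np.allclose(cmap(norm_answer(i)), expected)
    assert np.allclose(cmap(norm_shifted(i)), expected)

# Alternative scheme: limits shifted by half a step, so the ticks are simply 0..n-1
fig, ax = plt.subplots(figsize=(3, 4))
im = ax.imshow(np.arange(n).reshape(-1, 1), cmap=cmap, vmin=-0.5, vmax=n - 0.5,
               aspect='auto', interpolation='none')
cbar = fig.colorbar(im, ax=ax, ticks=range(n))
cbar.set_ticklabels(labels)
```

Applied to the answer's code, the alternative would pass vmin=-0.5 and vmax=len(unique_values)-0.5 to imshow and call cbar_f.set_ticks(range(len(unique_values))); both variants color each formation the same way and only differ in how the tick positions are computed.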