Discrete Color Bar with Tick labels in between colors


I am trying to plot some data with a discrete color bar. I was following the example at https://gist.github.com/jakevdp/91077b0cae40f8f8244a, but it does not carry over directly to a different spacing: the values in that example increase in steps of 1, whereas my data increases in steps of 0.5. The code below shows what I have so far, run on some random data. [Image: example output from some random data] Any help with this would be appreciated. I know I am missing something key here but can't figure it out.

import matplotlib.pylab as plt
import numpy as np

def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map"""

    # Note that if base_cmap is a string or None, you can simply do
    #    return plt.cm.get_cmap(base_cmap, N)
    # The following works for string, None, or a colormap instance:

    base = plt.cm.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    return base.from_list(cmap_name, color_list, N)

num = 11

x = np.random.randn(40)
y = np.random.randn(40)
c = np.random.randint(num, size=40)


plt.figure(figsize=(10,7.5))

plt.scatter(x, y, c=c, s=50, cmap=discrete_cmap(num, 'jet'))
plt.colorbar(ticks=np.arange(0,5.5,0.5))
plt.clim(-0.5, num - 0.5)
plt.show()

Solution

  • I'm not sure which version of matplotlib introduced this, but plt.get_cmap accepts an optional integer second argument giving the number of colors to sample from the colormap, e.g. plt.get_cmap('viridis', 5). A colormap created this way is discrete, so the colorbar comes out discrete as well; only the tick positions need adjusting so that each label sits at the center of its color. By the way, pandas handles the colorbar even better.

    import numpy as np
    from matplotlib import pyplot as plt
    plt.style.use('ggplot')
    # Jupyter/IPython only: remove the next line when running as a plain Python script
    %matplotlib inline
    
    # choose number of clusters and number of points in each cluster
    n_clusters = 5
    n_samples = 20
    
    # there are fancier ways to do this
    clusters = np.array([k for k in range(n_clusters) for i in range(n_samples)])
    
    # generate the coordinates of the center 
    # of each cluster by shuffling a range of values
    clusters_x = np.arange(n_clusters)
    clusters_y = np.arange(n_clusters)
    
    np.random.shuffle(clusters_x)
    np.random.shuffle(clusters_y)
    
    # get dicts like cluster -> center coordinate
    x_dict = dict(enumerate(clusters_x))
    y_dict = dict(enumerate(clusters_y))
    
    # get coordinates of cluster center for each point
    x = np.array(list(x_dict[k] for k in clusters)).astype(float)
    y = np.array(list(y_dict[k] for k in clusters)).astype(float)
    
    # add noise
    x += np.random.normal(scale=0.5, size=n_clusters*n_samples)
    y += np.random.normal(scale=0.5, size=n_clusters*n_samples)
    
    ### Finally, plot
    fig, ax = plt.subplots(figsize=(12,8))
    
    # get discrete colormap
    cmap = plt.get_cmap('viridis', n_clusters)
    
    # scatter points
    scatter = ax.scatter(x, y, c=clusters, cmap=cmap)
    
    # scatter cluster centers
    ax.scatter(clusters_x, clusters_y, c='red')
    
    # add colorbar
    cbar = plt.colorbar(scatter)
    
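    # set tick locations (not very elegant, but it works):
    # the color axis runs from 0 to n_clusters-1 and is split into
    # n_clusters equal bins, so each bin is (n_clusters-1)/n_clusters wide
    # - shift by 0.5 to move from the bin edge to the bin center
    # - scale by the bin width so the last tick is at the center of the last color
    tick_locs = (np.arange(n_clusters) + 0.5)*(n_clusters-1)/n_clusters
    cbar.set_ticks(tick_locs)
    
    # label each tick with its cluster index
    cbar.set_ticklabels(np.arange(n_clusters))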
    

    [Image: sample output plot]
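For the spacing in the question (values from 0 to 5 in steps of 0.5), one possible adaptation is sketched below. It assumes the data only takes those 11 values, and it widens the color limits by half a step on each side, which is the same idea as the plt.clim(-0.5, num - 0.5) call in the question but scaled to the 0.5 step. The names step, levels and bin_index are illustrative and do not come from the original posts; the Agg backend and the output file name are only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove for interactive use
import matplotlib.pyplot as plt
import numpy as np

# Assumed levels, taken from the question: 0, 0.5, ..., 5 (11 values)
step = 0.5
levels = np.linspace(0, 5, 11)

# Random data restricted to those levels
rng = np.random.default_rng(0)
x = rng.standard_normal(40)
y = rng.standard_normal(40)
c = rng.choice(levels, size=40)

# One color per level
cmap = plt.get_cmap("viridis", len(levels))

fig, ax = plt.subplots(figsize=(10, 7.5))

# Extend the color limits by half a step on each side so that every
# color bin is exactly one step wide and centered on one level
sc = ax.scatter(x, y, c=c, s=50, cmap=cmap,
                vmin=levels[0] - step / 2, vmax=levels[-1] + step / 2)

# With centered bins, the ticks can sit directly on the data values
cbar = fig.colorbar(sc, ticks=levels)

# Sanity check: each level should map to its own color bin (0..10)
bin_index = np.floor(np.asarray(sc.norm(levels)) * cmap.N).astype(int)

fig.savefig("discrete_colorbar.png")
```

Because the limits are shifted rather than the ticks, no rescaling of the tick locations is needed here; the tick labels are simply the data values.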