
Define custom seaborn color palette?


I am trying to construct a color palette to disambiguate a large number of stacked bars. When I use any of the discrete color palettes (e.g. muted), the colors repeat, and when I use any of the continuous color maps (e.g. cubehelix) the colors run together.

Using muted: [image]

Using cubehelix: [image]

I need a color palette containing a large number of distinct, non-contiguous colors. I think this can be achieved by taking an existing continuous color palette and permuting the colors. However, I don't know how to do this, and despite much googling I have not been able to figure out how to define a custom color palette.

Any help is much appreciated.


Solution

  • Matplotlib provides the tab20 colormap, a categorical map with 20 distinct colors, which might be suitable here.
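    Since the question is about seaborn, here is a minimal sketch of using tab20 as a seaborn palette. It assumes seaborn's color_palette() accepts matplotlib colormap names (it does for "tab20"); set_palette() then makes it the default color cycle.

```python
import matplotlib.pyplot as plt
import seaborn as sns

# tab20 is a categorical ("listed") colormap, so asking seaborn for 20 colors
# returns its 20 predefined entries in order instead of a sampled gradient.
palette = sns.color_palette("tab20", 20)

# The same colors are also available directly from matplotlib.
tab20_colors = plt.cm.tab20.colors

# Make the palette the default color cycle for later seaborn/matplotlib plots.
sns.set_palette(palette)
```

    Seaborn plotting functions also accept the palette directly, e.g. via their palette= argument, if you don't want to change the global default.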

    You could also take the colors from an existing continuous colormap and randomize their order.
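    A sketch of that idea in seaborn terms, using cubehelix from the question: sample n colors, then shuffle them. The seeded NumPy generator (seed 0 is an arbitrary choice) is an addition so the same "random" order is produced on every run.

```python
import numpy as np
import seaborn as sns

n = 20
# Sample n colors spread across a continuous palette (cubehelix, as in the question).
ordered = sns.color_palette("cubehelix", n)

# Shuffle the order with a seeded generator so the result is reproducible.
rng = np.random.default_rng(0)
order = rng.permutation(n)
shuffled = [ordered[i] for i in order]

# Use the shuffled list as the default color cycle.
sns.set_palette(shuffled)
```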

    There are also tools that generate a list of n distinct colors for you; such a list can be used directly as a custom palette.
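    Whatever tool produces the hex codes, turning such a list into a seaborn palette is a single call, because color_palette() accepts any list of matplotlib colors. The six codes below are simply the first entries of colors3 from the comparison example.

```python
import seaborn as sns

# A hand-picked (or tool-generated) list of hex codes.
custom_hex = ["#9d6d00", "#903ee0", "#11dc79", "#f568ff", "#419500", "#013fb0"]

# color_palette() converts the list to RGB tuples and returns a seaborn palette.
palette = sns.color_palette(custom_hex)
sns.set_palette(palette)

# The palette can be converted back to hex codes if needed.
hex_again = palette.as_hex()
```

    If you plot more layers than the list has colors, the cycle repeats, so the list needs at least as many entries as you have stacked bars.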

    Comparing those three options:

    import numpy as np
    import matplotlib.pyplot as plt
    plt.rcParams["axes.xmargin"] = 0
    plt.rcParams["axes.ymargin"] = 0
    
    # Take the colors of an existing categorical map
    colors1 = plt.cm.tab20.colors
    
    # Take the randomized colors of a continuous map
    inx = np.linspace(0,1,20)
    np.random.shuffle(inx)
    colors2 = plt.cm.nipy_spectral(inx)
    
    # Take a list of custom colors
    colors3 = ["#9d6d00", "#903ee0", "#11dc79", "#f568ff", "#419500", "#013fb0", 
              "#f2b64c", "#007ae4", "#ff905a", "#33d3e3", "#9e003a", "#019085", 
              "#950065", "#afc98f", "#ff9bfa", "#83221d", "#01668a", "#ff7c7c", 
              "#643561", "#75608a"]
    
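    fig = plt.figure()
    x = np.arange(10)
    # 20 layers at 10 x-positions; normalize each column to sum to 1
    y = np.random.rand(20, 10)+0.2
    y /= y.sum(axis=0)
    
    titles = ["tab20", "shuffled nipy_spectral", "custom list"]
    for i, colors in enumerate([colors1, colors2, colors3]):
        # The color cycle is read when the axes is created,
        # so add_subplot must be called inside the style context.
        with plt.style.context({"axes.prop_cycle" : plt.cycler("color", colors)}):
            ax = fig.add_subplot(1,3,i+1)
            ax.stackplot(x,y)
            ax.set_title(titles[i])
    plt.show()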
    

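    A random shuffle can still place two similar colors next to each other. One deterministic alternative (an assumption of mine, not part of the answer above) is a strided permutation: take every k-th color of the sampled palette, wrapping around, where k shares no common factor with n so every color is used exactly once. strided_palette is a hypothetical helper name; the stacked bars are drawn with plain matplotlib ax.bar calls.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def strided_palette(name, n, step):
    """Hypothetical helper: sample n colors from a continuous palette and
    reorder them so that consecutive entries are `step` positions apart."""
    if math.gcd(n, step) != 1:
        raise ValueError("step must be coprime with n to visit every color once")
    colors = sns.color_palette(name, n)
    order = [(i * step) % n for i in range(n)]
    return [colors[j] for j in order]


palette = strided_palette("cubehelix", 20, 7)

# Usage: 20 stacked segments on 10 bars, one palette color per segment.
rng = np.random.default_rng(0)
heights = rng.random((20, 10)) + 0.2
x = np.arange(heights.shape[1])
bottom = np.zeros(heights.shape[1])
fig, ax = plt.subplots()
for row, color in zip(heights, palette):
    ax.bar(x, row, bottom=bottom, color=color)
    bottom = bottom + row
plt.close(fig)  # use plt.show() or fig.savefig(...) instead to view the result
```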