
How to create a Matplotlib colormap that loops until a certain value?


I want a colormap that follows the 'bone' colormap for values from a to b, then continues as the reversed 'bone' colormap from b to 2b-a, so that it completes one cycle. This pattern should repeat n times and then stop, meaning that any value above the last value gets the same color as the last color in the colormap.

As a specific example, suppose a = 1, b = 5 and n = 3. The colormap should then look like bone from 1 to 5, reversed bone from 5 to 9, bone from 9 to 13, reversed bone from 13 to 17, bone from 17 to 21, and reversed bone from 21 to 25. All values above 25 should map to the last color of reversed bone.

How can I create this function?


Solution

  • A colormap's input range goes from 0 to 1. To assign a color to an arbitrary value, a norm is used: it maps the data values into the range 0-1. By default, plt.Normalize is the norm. It maps the lowest value encountered in the input, called vmin, to 0 and the highest value, called vmax, to 1, with values in between scaled linearly.

    One could create a custom norm for your use case:

    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize
    import numpy as np
    
    
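    class SawToothNorm(Normalize):
        """Triangle-wave norm: rises 0 to 1 over [a, b], falls back to 0 over [b, 2b-a], for n cycles."""

        def __init__(self, a, b, n, clip=False):
            self._a, self._b, self._n = a, b, n
            # vmax marks the end of the n-th full cycle (bone followed by reversed bone)
            super().__init__(vmin=a, vmax=a + (b - a) * 2 * n, clip=clip)
    
        def __call__(self, value, clip=None):
            a, b = self._a, self._b
            # Values outside [vmin, vmax] map to 0, which is also where the last cycle ends.
            # The public vmin/vmax attributes are used so this works across Matplotlib versions.
            return np.where((value < self.vmin) | (value > self.vmax), 0,
                            1 - np.abs(((value - a) / (b - a)) % 2 - 1))
    
        def inverse(self, value):
            # Only inverts the first rising segment (a to b); the norm itself is not one-to-one
            return value * (self._b - self._a) + self._a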
    
    
    my_norm = SawToothNorm(a=1, b=5, n=3)
    
    fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True, figsize=(15, 6))
    
    ax1.imshow(np.linspace(-1, 30, 500).reshape(-1, 1),
               norm=my_norm, cmap='bone',
               extent=[0, 1, -1, 30], aspect='auto', origin='lower')
    ax1.set_title('SawToothNorm(a=1, b=5, n=3) applied on y-values')
    
    y = np.linspace(-1, 30, 500)
    ax2.plot(my_norm(y), y)
    ax2.set_title('the norm maps values to the range 0-1')
    ax2.set_xlabel('normalized range')
    ax2.set_ylabel('input value range')
    ax2.tick_params(labelleft=True)
    
    plt.tight_layout()
    plt.show()
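
To check the formula by hand, here is a minimal sketch that applies the same arithmetic as `SawToothNorm.__call__` to single values, without Matplotlib or NumPy. The helper name `sawtooth` is made up for this illustration and is not part of the answer's class.

```python
def sawtooth(value, a=1, b=5, n=3):
    """Scalar version of the arithmetic in SawToothNorm.__call__ (illustrative helper)."""
    vmax = a + (b - a) * 2 * n           # end of the last cycle, 25 for a=1, b=5, n=3
    if value < a or value > vmax:
        return 0.0                       # outside the cycles: position 0 of the colormap
    # Triangle wave with period 2 * (b - a): 0 at a, 1 at b, back to 0 at 2b - a
    return 1 - abs(((value - a) / (b - a)) % 2 - 1)


for v in (1, 3, 5, 7, 9, 25, 30):
    print(v, sawtooth(v))
# 1 0.0
# 3 0.5
# 5 1.0
# 7 0.5
# 9 0.0
# 25 0.0
# 30 0.0
```

Each reversed-bone segment ends at position 0 of 'bone', so mapping everything above 25 to 0 gives those values the last color of the final reversed segment, as the question asks. Values below a also land on 0 with this formula.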
    

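
For comparison, the default behaviour described at the start of this answer can be checked directly. The sketch below only uses `matplotlib.colors.Normalize`; the variable names and sample values are arbitrary choices for illustration.

```python
from matplotlib.colors import Normalize

# Explicit limits: vmin=1 maps to 0, vmax=5 maps to 1, values in between scale linearly.
fixed = Normalize(vmin=1, vmax=5)
print(float(fixed(1)), float(fixed(3)), float(fixed(5)))   # 0.0 0.5 1.0

# No limits given: vmin and vmax are taken from the first data the norm is called with.
auto = Normalize()
scaled = auto([2, 4, 10])
print([float(x) for x in scaled])                          # [0.0, 0.25, 1.0]
```

The custom norm replaces this single linear ramp with a repeating up-and-down ramp, while the colormap itself ('bone') stays unchanged.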