Search code examples
Tags: matplotlib, colorbar

Align ticks with matplotlib colorbar


I am trying to align the ticks on a colorbar with its discrete colors, without much success. The following code runs, but the ticks are not aligned with the colors.

# Question, attempt 1: reuse the automatically chosen colorbar tick positions
# and relabel them with 7 evenly spaced values from 70 to 160.
import matplotlib.pyplot as plt
import numpy as np

import matplotlib as mpl
# NOTE(review): LinearSegmentedColormap, ListedColormap, cm, colors and Image
# are imported but not used anywhere in this snippet.
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
import matplotlib.cm as cm

from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as colors
from scipy import ndimage
from PIL import Image

# Built-in 'plasma' colormap resampled to 8 discrete colors.
plasma = mpl.colormaps['plasma'].resampled(8)
fig, ax = plt.subplots(figsize=(7, 5))

ax.axis("off")
# NOTE: image_file_name is not defined in this snippet; set it to the image
# path before running.
img = plt.imread(image_file_name)
# Rotate the image array by -90 degrees.
rotated_img = ndimage.rotate(img, -90)

# NOTE(review): the question describes the image as RGB; imshow() ignores
# cmap for RGB(A) arrays, so the colorbar may not describe the displayed
# pixels. Confirm the array is 2-D (single channel).
im = ax.imshow(rotated_img, cmap=plasma)

# Put the colorbar in its own axes to the right of the image axes.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="10%", pad=0.05)

cb = fig.colorbar(im, cax=cax, orientation="vertical")
# These tick positions are chosen automatically. They are not tied to the 8
# color bands and may change once the figure is actually drawn.
print(cb.ax.get_yticks())
# Freeze the auto-chosen positions as fixed ticks.
cb.set_ticks(cb.ax.get_yticks())
# 7 label values. Labels are paired with the ticks above by position only;
# they do not move the ticks onto the color boundaries.
# NOTE(review): assumes get_yticks() returned exactly 7 values; confirm.
ticks = np.linspace(70,160,num=7)
#cb.set_ticks(ticks)
cb.set_ticklabels([f'{tick:.2f}' for tick in ticks])
print(cb.ax.get_yticks())

yields

[Image: resulting plot for the first snippet (not included)]

whereas

# Question, attempt 2: place 9 evenly spaced ticks from 70 to 160 and label
# them with those same values.
import matplotlib.pyplot as plt
import numpy as np

import matplotlib as mpl
# NOTE(review): LinearSegmentedColormap, ListedColormap, cm, colors and Image
# are imported but not used anywhere in this snippet.
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
import matplotlib.cm as cm

from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as colors
from scipy import ndimage
from PIL import Image

# Built-in 'plasma' colormap resampled to 8 discrete colors.
plasma = mpl.colormaps['plasma'].resampled(8)    
fig, ax = plt.subplots(figsize=(7, 5))

ax.axis("off")
# NOTE: image_file_name is not defined in this snippet; set it to the image
# path before running.
img = plt.imread(image_file_name)
# Rotate the image array by -90 degrees.
rotated_img = ndimage.rotate(img, -90)

# NOTE(review): the question describes the image as RGB; imshow() ignores
# cmap for RGB(A) arrays, so the colorbar may not describe the displayed
# pixels. Confirm the array is 2-D (single channel).
im = ax.imshow(rotated_img, cmap=plasma)

# Put the colorbar in its own axes to the right of the image axes.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="10%", pad=0.05)

cb = fig.colorbar(im, cax=cax, orientation="vertical")
print(cb.ax.get_yticks())
#cb.set_ticks(cb.ax.get_yticks())
# 9 tick positions given in the image's *data* units. If the data spans
# 0-255 (as the question states), 70..160 covers only the middle of the
# colorbar and will not line up with the 8 color bands. Confirm the range.
ticks = np.linspace(70,160,num=9)
cb.set_ticks(ticks)
cb.set_ticklabels([f'{tick:.2f}' for tick in ticks])
print(cb.ax.get_yticks())

yields

[Image: resulting plot for the second snippet (not included)]

The colormap is obtained from the built-in "plasma" colormap:

plasma = mpl.colormaps['plasma'].resampled(8)

The image data is simply RGB values, 0 to 255 for each channel. The "hottest" (whitest) color corresponds to the maximum measured temperature.


Solution

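  • This answer assumes you want the lowest value in the image mapped to 70 and the highest to 160. It also assumes single-channel (2-D) image data: for RGB(A) arrays, imshow() ignores the colormap, so the colorbar would not describe the displayed colors.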

    imshow() internally maps the lowest value in the data to the dark purple color of the colormap and the highest value to yellow. We can make this explicit by setting vmin and vmax; these then become the lowest and highest y values of the colorbar. By default, the ticks are chosen as "nice", rounded values between the lowest and highest value, independent of the number of colors in the colormap. .get_ticks() (or .get_yticks()) can be very tricky, because the final ticks are only decided when the plot is drawn to the screen. That's why you get the 300 value, which is outside the range and causes the strange white region if you pass it to set_ticks().

    With cb.set_ticks(), you set the positions of the ticks. Here you want them to run from the lowest to the highest y value, with a tick on every boundary between color regions. For N colors, that is N + 1 evenly spaced ticks, including both ends.

    With cb.set_ticklabels(), you set the corresponding labels. Since you want to map the image values (e.g. 0 to 255) to the temperatures shown, you can use the transformed values (here 70 to 160) as the labels.

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    import numpy as np
    from scipy.ndimage import gaussian_filter

    # Create a smooth dummy single-channel test image with integer values in
    # 0..255, matching the 0-255 channel range described in the question.
    rng = np.random.default_rng(seed=0)  # fixed seed -> reproducible figure
    img = gaussian_filter(rng.random((1000, 1000)), 80)
    img -= img.min()
    img /= img.max()
    # Scale by 255, not 256, so the maximum stays inside 0..255.
    img = np.array(img * 255, dtype=int)

    fig, ax = plt.subplots(figsize=(7, 5))
    ax.axis("off")

    # Discrete colormap: built-in 'plasma' resampled to num_colors bands.
    num_colors = 8
    plasma = mpl.colormaps['plasma'].resampled(num_colors)

    # Make the data range explicit. The colorbar then spans exactly
    # [vmin, vmax], and each of the num_colors bands covers an equal share.
    vmin = img.min()
    vmax = img.max()
    im = ax.imshow(img, cmap=plasma, vmin=vmin, vmax=vmax)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="10%", pad=0.05)
    cb = fig.colorbar(im, cax=cax)

    # num_colors + 1 evenly spaced positions fall exactly on the band
    # boundaries (both ends included). Relabel them with the 70..160 scale.
    cb.set_ticks(np.linspace(vmin, vmax, num_colors + 1))
    cb.set_ticklabels([f'{t:.2f}' for t in np.linspace(70, 160, num_colors + 1)])
    plt.tight_layout()
    plt.show()
    

    [Image: colorbar with ticks set on the boundaries between colors]