Tags: python, matplotlib, axis-labels

Specific axis formatting in matplotlib


I am trying to format the axis labels on a matplotlib graph in a very specific way. My code is the following:

import matplotlib.pyplot as plt
import numpy as np

some_matrix = ...
alpha_values = list(np.power([2.0]*20, range(-12,8)))
gamma_values = list(np.power([2.0]*20, range(-12,8)))

plotted_matrix = plt.matshow(some_matrix)
plt.colorbar()
xtick_marks = np.arange(len(alpha_values))
ytick_marks = np.arange(len(gamma_values))
plt.xticks(xtick_marks,alpha_values)
plt.yticks(ytick_marks,gamma_values)
plt.xlabel('Alpha values',size='small')
plt.ylabel('Gamma values',size='small')

As you can see, my x (alpha) and y (gamma) tick labels are all powers of 2. Is it possible to display them as powers of 2, i.e. with labels 2^-12, 2^-11, ..., 2^7 (ideally with the exponent as a proper superscript)?
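A minimal sketch that stays with the pyplot calls from the question: because the ticks already sit at integer positions, the labels can be passed directly as mathtext strings such as `$2^{-12}$`, which matplotlib renders with a real superscript. `power_of_two_labels` is a hypothetical helper name, the sketch assumes every value is an exact power of two, and random data stands in for the elided `some_matrix`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

# Same tick values as in the question: 2^-12 ... 2^7
alpha_values = 2.0 ** np.arange(-12, 8)
gamma_values = 2.0 ** np.arange(-12, 8)

def power_of_two_labels(values):
    """Hypothetical helper: turn exact powers of two into mathtext labels like '$2^{-12}$'."""
    exponents = np.log2(values).round().astype(int)
    return ["$2^{%d}$" % e for e in exponents]

# Placeholder data: rows follow the y axis (gamma), columns the x axis (alpha)
some_matrix = np.random.rand(len(gamma_values), len(alpha_values))

plt.matshow(some_matrix)
plt.colorbar()
plt.xticks(np.arange(len(alpha_values)), power_of_two_labels(alpha_values))
plt.yticks(np.arange(len(gamma_values)), power_of_two_labels(gamma_values))
plt.xlabel('Alpha values', size='small')
plt.ylabel('Gamma values', size='small')
```

With 20 labels per axis the x labels may crowd; labelling only every other tick is one way around that.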

I have tried the following:

y_formatter = plt.ticker.ScalarFormatter('#**2')
plotted_matrix.yaxis.set_major_formatter(y_formatter)

But I get the error message 'AxesImage' object has no attribute 'yaxis'.
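To illustrate where the error comes from: `matshow` returns an `AxesImage` (the drawn image artist), and only the `Axes` has `xaxis`/`yaxis`. Every artist keeps a reference to the Axes it lives in through its `axes` attribute, so a minimal fix to the attempt above could look like the sketch below. Note also that `ScalarFormatter` does not take a format string (its first parameter is `useOffset`), so a `FuncFormatter` is used instead. The fixed offset of 12 assumes the exponent range -12..7 from the question, and the random matrix is a placeholder.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

some_matrix = np.random.rand(20, 20)  # placeholder data

plotted_matrix = plt.matshow(some_matrix)
# matshow returns an AxesImage, which has no .xaxis/.yaxis;
# the Axes it was drawn into is available as its .axes attribute.
ax = plotted_matrix.axes

def tick_format(x, pos):
    # Tick position 0..19 corresponds to exponent -12..7
    return "$2^{%d}$" % round(x - 12)

ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_formatter(ticker.FuncFormatter(tick_format))
```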


Solution

  • You can use a ticker.FuncFormatter to define the tick labels, and a ticker.MultipleLocator to place them.

    The error occurs because you are setting the tick formatter on the AxesImage returned by matshow rather than on the Axes instance. Using matplotlib's object-oriented interface (fig, ax = plt.subplots()) makes it straightforward to work with the Axes directly.

    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
    import numpy as np
    
    alpha_values = list(np.power([2.0]*20, range(-12,8)))
    gamma_values = list(np.power([2.0]*20, range(-12,8)))
    
    # matshow draws rows along y (gamma) and columns along x (alpha)
    some_matrix = np.random.rand(len(gamma_values),len(alpha_values))
    
    fig,ax = plt.subplots()
    
    plotted_matrix = ax.matshow(some_matrix)
    fig.colorbar(plotted_matrix)
    
    def tick_format(x,pos):
        # Map tick positions 0..19 to exponents -12..7, rendered as mathtext superscripts
        return "$2^{{ {:.0f} }}$".format(x-12)
    
    # Tick every 2nd column on x and every row on y, labelling each with tick_format
    ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(tick_format))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(tick_format))
    
    ax.set_xlabel('Alpha values',size='small')
    ax.set_ylabel('Gamma values',size='small')
    
    plt.show()
    

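The `tick_format` above hard-codes the offset 12. A possible variation, sketched under the assumption that the values are exact powers of two, reads the exponents from `alpha_values` and `gamma_values` themselves and leaves ticks that fall outside the matrix unlabelled. `power_of_two_formatter` is a hypothetical helper, not a matplotlib API.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

def power_of_two_formatter(values):
    """Hypothetical helper: label tick index i as 2^log2(values[i])."""
    exponents = np.log2(values).round().astype(int)

    def tick_format(x, pos):
        i = int(round(x))
        if 0 <= i < len(exponents):
            return "$2^{%d}$" % exponents[i]
        return ""  # tick outside the matrix: leave it unlabelled

    return ticker.FuncFormatter(tick_format)

alpha_values = 2.0 ** np.arange(-12, 8)
gamma_values = 2.0 ** np.arange(-12, 8)
some_matrix = np.random.rand(len(gamma_values), len(alpha_values))  # placeholder data

fig, ax = plt.subplots()
plotted_matrix = ax.matshow(some_matrix)
fig.colorbar(plotted_matrix)

ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
ax.xaxis.set_major_formatter(power_of_two_formatter(alpha_values))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_formatter(power_of_two_formatter(gamma_values))
ax.set_xlabel('Alpha values', size='small')
ax.set_ylabel('Gamma values', size='small')
```

Because the formatter only looks up the tick index, changing the value range (for example to 2^-5 ... 2^14) needs no change to the formatting code.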