I am trying to format the axis labels on a matplotlib graph in a very specific way. My code is the following:
import matplotlib.pyplot as plt
import numpy as np
some_matrix = ...
alpha_values = list(np.power([2.0]*20, range(-12,8)))
gamma_values = list(np.power([2.0]*20, range(-12,8)))
plotted_matrix = plt.matshow(some_matrix)
plt.colorbar()
xtick_marks = np.arange(len(alpha_values))
ytick_marks = np.arange(len(gamma_values))
plt.xticks(xtick_marks,alpha_values)
plt.yticks(ytick_marks,gamma_values)
plt.xlabel('Alpha values',size='small')
plt.ylabel('Gamma values',size='small')
As you can see, my x (alpha) and y (gamma) labels are all powers of 2. I would like to know whether it is possible to display the axis labels as powers of 2, i.e. have labels 2^-1, 2^1, 2^2, ... (if possible, with the exponent as a proper superscript).
I have tried the following:
y_formatter = plt.ticker.ScalarFormatter('#**2')
plotted_matrix.yaxis.set_major_formatter(y_formatter)
But I get the error message 'AxesImage' object has no attribute 'yaxis'.
You can use a ticker.FuncFormatter to define the tick labels, and a ticker.MultipleLocator to place them.
The error occurs because you are trying to set the tick formatter on the AxesImage returned by matshow, not on the Axes instance. Only the Axes has the xaxis and yaxis attributes. Here we can use the matplotlib object-oriented approach to make things simpler.
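To make the distinction concrete, here is a minimal sketch. It assumes a headless Agg backend and a random 20x20 matrix as stand-in data, neither of which comes from the question. It checks that matshow returns an AxesImage without a yaxis attribute, and that the owning Axes can be reached through the image's axes attribute.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.image import AxesImage

some_matrix = np.random.rand(20, 20)  # stand-in data, not from the question

fig, ax = plt.subplots()
image = ax.matshow(some_matrix)

# matshow returns the image artist, which has no xaxis/yaxis attributes
is_image = isinstance(image, AxesImage)
has_yaxis = hasattr(image, "yaxis")

# The Axes that owns the image (and the axis objects) is available as image.axes
owner = image.axes
```

With the pyplot-style call from the question, plotted_matrix.axes (or plt.gca()) should give the same Axes, so formatters can be set on it. The full example that follows works on the Axes directly.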
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
alpha_values = list(np.power([2.0]*20, range(-12,8)))
gamma_values = list(np.power([2.0]*20, range(-12,8)))
some_matrix = np.random.rand(len(gamma_values),len(alpha_values))  # rows follow y (gamma), columns follow x (alpha)
fig,ax = plt.subplots()
plotted_matrix = ax.matshow(some_matrix)
fig.colorbar(plotted_matrix)
def tick_format(x,pos):
    # Map tick positions 0..19 to the exponents -12..7, i.e. labels 2^-12 to 2^7
    return "$2^{{ {:.0f} }}$".format(x-12)
ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
ax.xaxis.set_major_formatter(ticker.FuncFormatter(tick_format))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_formatter(ticker.FuncFormatter(tick_format))
ax.set_xlabel('Alpha values',size='small')
ax.set_ylabel('Gamma values',size='small')
plt.show()
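The offset of 12 in tick_format is tied to this particular range of exponents. As a hedged variation, the sketch below wraps the same idea in a small factory so the lowest exponent becomes a parameter. The names make_pow2_formatter and min_exponent are hypothetical helpers introduced here for illustration, not part of matplotlib.

```python
import matplotlib.ticker as ticker

def make_pow2_formatter(min_exponent):
    """Return a FuncFormatter that labels tick position i as 2^(i + min_exponent)."""
    def _format(x, pos):
        # Tick positions arrive as floats; round to the nearest integer index
        exponent = int(round(x)) + min_exponent
        return "$2^{{{}}}$".format(exponent)
    return ticker.FuncFormatter(_format)

# Same mapping as the example above: position 0 corresponds to 2^-12
formatter = make_pow2_formatter(-12)
labels = [formatter(i, None) for i in (0, 12, 19)]
```

It would be attached in the same way, for example ax.xaxis.set_major_formatter(make_pow2_formatter(-12)). The dollar signs make matplotlib render the label with mathtext, which is what turns the exponent into a superscript.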