I have a large pandas MultiIndex DataFrame that I would like to plot. A minimal example would look like:
import pandas as pd
years = range(2015, 2018)
fields = range(4)
days = range(4)
bands = ['R', 'G', 'B']
index = pd.MultiIndex.from_product(
    [years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
    [days, bands], names=['day', 'band'])
df = pd.DataFrame(0, index=index, columns=columns)
df.loc[(2015,), (0,)] = 1
df.loc[(2016,), (1,)] = 1
df.loc[(2017,), (2,)] = 1
If I plot this using plt.spy, I get:
However, the tick locations and labels are less than desirable. I would like the ticks to completely ignore the second level of the MultiIndex. Using IndexLocator and IndexFormatter, I'm able to do the following:
from matplotlib.ticker import IndexFormatter, IndexLocator
import matplotlib.pyplot as plt
ax = plt.gca()
plt.spy(df)
xbase = len(bands)
xoffset = xbase / 2
xlabels = df.columns.get_level_values('day')
ax.xaxis.set_major_locator(IndexLocator(base=xbase, offset=xoffset))
ax.xaxis.set_major_formatter(IndexFormatter(xlabels))
plt.xlabel('Day')
ax.xaxis.tick_bottom()
ybase = len(fields)
yoffset = ybase / 2
ylabels = df.index.get_level_values('year')
ax.yaxis.set_major_locator(IndexLocator(base=ybase, offset=yoffset))
ax.yaxis.set_major_formatter(IndexFormatter(ylabels))
plt.ylabel('Year')
plt.show()
This gives me exactly what I want:
But here's the problem. My actual DataFrame has 15 years, 4,000 fields, 365 days, and 7 bands. If I actually label every single day, the labels would be illegible. I could place a tick every 50 days, but I would like the ticks to be dynamic so that when I zoom in, the ticks become more fine-grained. Basically what I'm looking for is a custom MultiIndexLocator that combines the placement of IndexLocator with the dynamism of MaxNLocator.
Bonus: My data is really nice in the sense that there are always the same number of fields for every year and the same number of bands for every day. But what if this was not the case? I would love to contribute a generic MultiIndexLocator and MultiIndexFormatter to matplotlib that works for any MultiIndex DataFrame.
Matplotlib does not know about DataFrames or MultiIndex. It simply plots the data you supply; that is, you get the same result as if you were plotting the underlying numpy array, spy(df.values).
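As a quick check of that claim, here is a minimal sketch that draws a small MultiIndex frame once as a DataFrame and once as a plain array, then compares the two images. The tiny frame and the Agg backend are assumptions made here so the snippet runs without a display; they are not taken from the question.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, nothing is shown
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# A small frame with the same kind of row/column MultiIndex as in the question
index = pd.MultiIndex.from_product(
    [[2015, 2016], range(2)], names=["year", "field"])
columns = pd.MultiIndex.from_product(
    [range(2), ["R", "G", "B"]], names=["day", "band"])
df = pd.DataFrame(np.eye(4, 6), index=index, columns=columns)

fig, (ax1, ax2) = plt.subplots(1, 2)
im_df = ax1.spy(df)          # DataFrame passed directly
im_np = ax2.spy(df.values)   # plain numpy array

# Compare the image data and the pixel coordinates each image spans
same_data = np.array_equal(im_df.get_array(), im_np.get_array())
same_extent = tuple(im_df.get_extent()) == tuple(im_np.get_extent())
print(same_data, same_extent)
```

Both comparisons should come out true: the index labels never reach the axes, so the ticks can only be in pixel positions unless you supply something else.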
So I would suggest first setting the extent of the image so that the axes are in data units (days and years), which lets you use the ordinary numeric tickers. A MaxNLocator then works fine, as long as you do not zoom in too far.
import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
plt.rcParams['axes.formatter.useoffset'] = False
years = range(2000, 2018)
fields = range(9) #17
days = range(120) #365
bands = ['R', 'G', 'B', 'A']
index = pd.MultiIndex.from_product(
    [years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
    [days, bands], names=['day', 'band'])
data = np.random.rand(len(years)*len(fields),len(days)*len(bands))
x,y = np.meshgrid(np.arange(data.shape[1]),np.arange(data.shape[0]))
data += 2*((y//len(fields)+x//len(bands)) % 2)
df = pd.DataFrame(data, index=index, columns=columns)
############
# Plotting
############
xbase = len(bands)
xlabels = df.columns.get_level_values('day')
ybase = len(fields)
ylabels = df.index.get_level_values('year')
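# Half the label spacing, so the extent reaches the outer edge of each group
dx = np.diff(np.unique(xlabels))[0]/2.
dy = np.diff(np.unique(ylabels))[0]/2.
# extent is (left, right, bottom, top); imshow draws row 0 (the first year)
# at the top by default, so the smallest year has to be the top edge
extent = [xlabels.min()-dx, xlabels.max()+dx,
          ylabels.max()+dy, ylabels.min()-dy]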
fig, ax = plt.subplots()
ax.imshow(df.values, extent=extent, aspect="auto")
ax.set_ylabel('Year')
ax.set_xlabel('Day')
ax.xaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
ax.yaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
plt.show()
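If you would rather stay in pixel coordinates, as in the question, the following is a rough sketch of the kind of locator the question asks for. It is not part of Matplotlib: GroupLocator, group_formatter, group_edges and plot_grouped_spy are hypothetical names introduced here. The idea is to let a MaxNLocator pick "nice" group numbers among the groups currently in view and then place the ticks at the centres of those groups. Since IndexFormatter has been removed from recent Matplotlib releases, the labels come from a FuncFormatter. Groups may have different sizes, under the assumption that every outer-level label forms one contiguous run.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter, Locator, MaxNLocator


def group_edges(level_values):
    """Start position of every run of equal labels, plus the total length."""
    values = np.asarray(level_values)
    # A new group starts wherever a label differs from the one before it
    starts = np.flatnonzero(np.r_[True, values[1:] != values[:-1]])
    return np.r_[starts, len(values)]


class GroupLocator(Locator):
    """Hypothetical locator: ticks at the centres of outer-level groups."""

    def __init__(self, level_values, nbins=8):
        super().__init__()
        edges = group_edges(level_values)
        # Pixel k is centred on k, so a group covering pixels a..b-1
        # is centred on (a + b - 1) / 2
        self.centers = (edges[:-1] + edges[1:] - 1) / 2
        self._chooser = MaxNLocator(nbins=nbins, integer=True)

    def __call__(self):
        vmin, vmax = self.axis.get_view_interval()
        return self.tick_values(vmin, vmax)

    def tick_values(self, vmin, vmax):
        lo, hi = sorted((vmin, vmax))  # the y-axis of spy is inverted
        # First and last group whose centre lies inside the view
        first = np.searchsorted(self.centers, lo, side="left")
        last = np.searchsorted(self.centers, hi, side="right") - 1
        if last < first:
            return np.array([])
        if last == first:
            return self.centers[[first]]
        # Let MaxNLocator choose group numbers, then keep valid integers only
        picks = self._chooser.tick_values(first, last)
        picks = picks[np.isclose(picks, np.round(picks))]
        picks = np.unique(np.round(picks).astype(int))
        picks = picks[(picks >= first) & (picks <= last)]
        return self.centers[picks]


def group_formatter(level_values):
    """Hypothetical formatter: label a position with its group's label."""
    values = np.asarray(level_values)
    edges = group_edges(values)

    def fmt(x, pos=None):
        # Index of the group whose pixel range contains x
        i = np.searchsorted(edges, x + 0.5, side="right") - 1
        if 0 <= i < len(edges) - 1:
            return str(values[edges[i]])
        return ""

    return FuncFormatter(fmt)


def plot_grouped_spy(df, row_level="year", col_level="day"):
    """Draw spy(df) with ticks that follow the outer MultiIndex levels."""
    fig, ax = plt.subplots()
    ax.spy(df, aspect="auto")
    rows = df.index.get_level_values(row_level)
    cols = df.columns.get_level_values(col_level)
    ax.xaxis.set_major_locator(GroupLocator(cols))
    ax.xaxis.set_major_formatter(group_formatter(cols))
    ax.yaxis.set_major_locator(GroupLocator(rows))
    ax.yaxis.set_major_formatter(group_formatter(rows))
    ax.xaxis.tick_bottom()
    ax.set_xlabel(col_level)
    ax.set_ylabel(row_level)
    return fig, ax
```

With the df from the question, calling plot_grouped_spy(df) followed by plt.show() should give one tick per year and per day on the small example, and a thinned-out, zoom-dependent selection on a large frame. One difference from the extent-based approach: the "nice" steps apply to the group number, not to the label value, so the labels only look round when the outer level is evenly spaced.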