
Matplotlib: custom ticker for pandas MultiIndex DataFrame


I have a large pandas MultiIndex DataFrame that I would like to plot. A minimal example would look like:

import pandas as pd

years = range(2015, 2018)
fields = range(4)
days = range(4)
bands = ['R', 'G', 'B']

index = pd.MultiIndex.from_product(
    [years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
    [days, bands], names=['day', 'band'])

df = pd.DataFrame(0, index=index, columns=columns)

df.loc[(2015,), (0,)] = 1
df.loc[(2016,), (1,)] = 1
df.loc[(2017,), (2,)] = 1

If I plot this using plt.spy, I get:

[Image: simple plot]

However, the tick locations and labels are less than desirable. I would like the ticks to completely ignore the second level of the MultiIndex. Using IndexLocator and IndexFormatter, I'm able to do the following:

from matplotlib.ticker import IndexFormatter, IndexLocator

import matplotlib.pyplot as plt

ax = plt.gca()
plt.spy(df)

xbase = len(bands)
xoffset = xbase / 2
xlabels = df.columns.get_level_values('day')
ax.xaxis.set_major_locator(IndexLocator(base=xbase, offset=xoffset))
ax.xaxis.set_major_formatter(IndexFormatter(xlabels))
plt.xlabel('Day')
ax.xaxis.tick_bottom()

ybase = len(fields)
yoffset = ybase / 2
ylabels = df.index.get_level_values('year')
ax.yaxis.set_major_locator(IndexLocator(base=ybase, offset=yoffset))
ax.yaxis.set_major_formatter(IndexFormatter(ylabels))
plt.ylabel('Year')

plt.show()

This gives me exactly what I want:

[Image: the spy plot with one tick per day on the x-axis and one per year on the y-axis]
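A note for readers on newer Matplotlib: IndexFormatter was deprecated in Matplotlib 3.3 and is absent from later releases, so the import above fails there. Below is a minimal stand-in built on FuncFormatter. The helper name `make_index_formatter` is made up for this sketch and is not a Matplotlib API; it rounds the tick position to the nearest integer index the way IndexFormatter did.

```python
import math

from matplotlib.ticker import FuncFormatter


def make_index_formatter(labels):
    """Hypothetical helper: label a tick with labels[nearest integer index]."""
    labels = list(labels)

    def _format(x, pos=None):
        i = math.floor(x + 0.5)  # nearest integer index
        if i < 0 or i >= len(labels):
            return ''            # outside the data: no label
        return str(labels[i])

    return FuncFormatter(_format)
```

It would be used in place of `IndexFormatter(xlabels)`, for example `ax.xaxis.set_major_formatter(make_index_formatter(xlabels))`.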

But here's the problem. My actual DataFrame has 15 years, 4,000 fields, 365 days, and 7 bands. If I actually label every single day, the labels would be illegible. I could place a tick every 50 days, but I would like the ticks to be dynamic so that when I zoom in, the ticks become more fine-grained. Basically what I'm looking for is a custom MultiIndexLocator that combines the placement of IndexLocator with the dynamism of MaxNLocator.

Bonus: My data is really nice in the sense that there are always the same number of fields for every year and the same number of bands for every day. But what if this was not the case? I would love to contribute a generic MultiIndexLocator and MultiIndexFormatter to matplotlib that works for any MultiIndex DataFrame.


Solution

  • Matplotlib does not know about dataframes or MultiIndex. It simply plots the data you supply. I.e. you get the same as if you were plotting the numpy array of data, spy(df.values).

    So I would suggest first setting the extent of the image correctly, so that the axes are in units of days and years and you can use numeric tickers. A MaxNLocator should then work fine, as long as you do not zoom in too much.

    import numpy as np
    import pandas as pd
    from matplotlib.ticker import MaxNLocator
    import matplotlib.pyplot as plt
    plt.rcParams['axes.formatter.useoffset'] = False
    
    years = range(2000, 2018)
    fields = range(9) #17
    days = range(120) #365
    bands = ['R', 'G', 'B', 'A']
    
    index = pd.MultiIndex.from_product(
        [years, fields], names=['year', 'field'])
    columns = pd.MultiIndex.from_product(
        [days, bands], names=['day', 'band'])
    
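    # Uniform noise plus a checkerboard offset, so that every
    # (year, day) block is visible in the image.
    data = np.random.rand(len(years)*len(fields), len(days)*len(bands))
    x, y = np.meshgrid(np.arange(data.shape[1]), np.arange(data.shape[0]))
    data += 2*((y//len(fields) + x//len(bands)) % 2)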
    df = pd.DataFrame(data, index=index, columns=columns)
    
    ############
    # Plotting
    ############
    
    xbase = len(bands)
    xlabels = df.columns.get_level_values('day')
    ybase = len(fields)
    ylabels = df.index.get_level_values('year')
    
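    # Half the spacing between consecutive labels on each axis.
    dx = np.diff(np.unique(xlabels))[0] / 2.
    dy = np.diff(np.unique(ylabels))[0] / 2.
    # extent is (left, right, bottom, top). With imshow's default
    # origin='upper', row 0 is drawn at `top`, so the first year goes there.
    extent = [xlabels.min() - dx, xlabels.max() + dx,
              ylabels.max() + dy, ylabels.min() - dy]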
    
    fig, ax = plt.subplots()
    
    ax.imshow(df.values, extent=extent, aspect="auto")
    ax.set_ylabel('Year')
    ax.set_xlabel('Day')
    
    ax.xaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
    
    
    plt.show()
    

    [Image: the resulting imshow plot]
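If you would rather keep the row/column coordinates that spy uses (no extent), or the groups do not all have the same size as in the bonus question, a custom locator is another option. The following is only a sketch under stated assumptions: `group_centers`, `GroupLocator` and `group_formatter` are hypothetical names, not part of Matplotlib; the outer level is assumed to be non-empty and sorted so that equal values are contiguous; and ticks are thinned by keeping every n-th group rather than by picking "nice" numbers as MaxNLocator does.

```python
import numpy as np
from matplotlib.ticker import FuncFormatter, Locator


def group_centers(level_values):
    """Return (labels, centers) for each run of equal consecutive values.

    Centers are in row/column index coordinates, where element i is
    centered at i, as with spy/imshow when no extent is given.
    """
    values = np.asarray(level_values)
    # Index of the first element of every run.
    starts = np.flatnonzero(np.r_[True, values[1:] != values[:-1]])
    stops = np.r_[starts[1:], len(values)]  # one past the end of each run
    return values[starts], (starts + stops - 1) / 2.0


class GroupLocator(Locator):
    """Hypothetical locator: ticks at group centers, thinned to fit the view."""

    def __init__(self, centers, max_ticks=10):
        self.centers = np.asarray(centers, dtype=float)
        self.max_ticks = int(max_ticks)

    def __call__(self):
        vmin, vmax = self.axis.get_view_interval()
        return self.tick_values(vmin, vmax)

    def tick_values(self, vmin, vmax):
        # The view interval may be reversed (spy inverts the y-axis).
        lo, hi = min(vmin, vmax), max(vmin, vmax)
        idx = np.flatnonzero((self.centers >= lo) & (self.centers <= hi))
        # Keep groups whose index is a multiple of `step`, so ticks stay
        # on the same groups while panning.
        step = max(1, int(np.ceil(len(idx) / self.max_ticks)))
        return self.centers[idx[idx % step == 0]]


def group_formatter(labels, centers):
    """Hypothetical formatter: label a tick with its nearest group's label."""
    labels = np.asarray(labels)
    centers = np.asarray(centers, dtype=float)

    def _format(x, pos=None):
        return str(labels[np.abs(centers - x).argmin()])

    return FuncFormatter(_format)
```

To use it, one would compute `labels, centers = group_centers(df.index.get_level_values('year'))` and then pass `GroupLocator(centers)` to `ax.yaxis.set_major_locator` and `group_formatter(labels, centers)` to `ax.yaxis.set_major_formatter`; the x-axis works the same way with the 'day' level of the columns. Because the locator reads the view interval on every draw, zooming in leaves fewer groups in view and the thinning step shrinks accordingly. One limitation of this sketch: when the view lies entirely inside a single group and does not contain its center, no tick is drawn.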