Tags: python-3.x, matplotlib, grid-layout, figure, multiple-axes

Can one get the number of rows and columns from an instance of the Figure class from matplotlib?


Question:

Suppose one would like to place 4 subplots in one figure, either as 4 rows by 1 column or as 1 row by 4 columns. One can use fig, axes = plt.subplots(nrows=..., ncols=...) to create these subplots. However, passing nrows=4, ncols=1 and passing nrows=1, ncols=4 both return an axes array with the same shape, axes.shape == (4,). Since the shapes are identical, how does matplotlib keep track of the number of rows and columns in the figure? Can nrows and ncols be obtained from the fig or axes instance?

MWE:

In case the above is not clear, one can run the code below to create such a figure (note the print statement):

import numpy as np
import matplotlib.pyplot as plt

## sample data
x = np.arange(10)
y1 = np.cos(x)
y2 = np.sin(x)
y3 = np.tan(x)
y4 = 1 / y3  # tan(0) == 0, so the first element is inf and NumPy warns about division by zero

## make easy to identify
labels = ('cos', 'sin', 'tan', r'$\frac{1}{tan}$')
facecolors = ('darkorange', 'steelblue', 'purple', 'green')

## initialize plot
# fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12,7)) ## shape=(2,2)
fig, axes = plt.subplots(nrows=4, ncols=1, figsize=(12,7)) ## shape=(4,)
# fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(12,7)) ## shape=(4,)

## verify shape of axes
print(axes.shape)

## create a plot
for ax, y, label, facecolor in zip(axes.ravel(), (y1, y2, y3, y4), labels, facecolors):
    ax.plot(x, y, label=label, color=facecolor)

## add legend
fig.subplots_adjust(bottom=0.2)
fig.legend(loc='lower center', mode='expand', fontsize=8, ncol=4)

## fig.nrows raises AttributeError: 'Figure' object has no attribute 'nrows'

## show and close
plt.show()
plt.close(fig)

The answer to a different but related question suggests the following solution, but it raises an error:

for f in fig.get_children():
    print(f.colNum, f.rowNum)

# AttributeError: 'Rectangle' object has no attribute 'colNum'

I suppose one could wrap this loop in try/except to skip the children that lack these attributes, but I am wondering if there is a cleaner way.
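For comparison, here is a minimal sketch of that loop without try/except, assuming a reasonably recent matplotlib in which every Axes has get_subplotspec(). It iterates over fig.axes, which contains only the Axes (fig.get_children() also returns the figure's background Rectangle patch, which is what triggered the error above), and reads each Axes' grid position from the rowspan and colspan of its SubplotSpec.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=4, ncols=1)

# fig.axes holds only Axes objects, so no Rectangle patch shows up here
positions = []
for ax in fig.axes:
    ss = ax.get_subplotspec()  # None for Axes not placed on a grid
    if ss is None:
        continue
    # rowspan/colspan are range objects over the grid cells this Axes occupies
    positions.append((ss.rowspan.start, ss.colspan.start))

print(positions)  # [(0, 0), (1, 0), (2, 0), (3, 0)]
plt.close(fig)
```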


Solution

  • When you call plt.subplots(), matplotlib uses a GridSpec to create the subplots. A figure can hold several GridSpecs besides the one used to create the initial axes, so you cannot get the GridSpec from the figure itself, but you can get it from the axes:

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(nrows=1, ncols=4)  ## shape=(4,)
    gs = axes[0].get_gridspec()  # the GridSpec shared by all four axes
    gs.nrows  # returns 1
    gs.ncols  # returns 4
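A short sketch that checks both layouts from the question side by side, assuming a reasonably recent matplotlib where get_gridspec() is available on the Axes returned by plt.subplots(). It uses axes.flat[0] so the same line also works for a 2D axes array such as nrows=2, ncols=2 (a single Axes, from nrows=ncols=1, has no .flat). It also shows an alternative that avoids the GridSpec entirely: passing squeeze=False to plt.subplots keeps axes as a 2D array whose shape is always (nrows, ncols).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running the sketch
import matplotlib.pyplot as plt

results = {}
for nrows, ncols in [(4, 1), (1, 4)]:
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    gs = axes.flat[0].get_gridspec()  # GridSpec shared by all axes made by subplots()
    results[(nrows, ncols)] = (axes.shape, gs.nrows, gs.ncols)
    print(axes.shape, gs.nrows, gs.ncols, gs.get_geometry())
    plt.close(fig)

# Alternative: squeeze=False keeps a 2D array, so its shape is (nrows, ncols)
fig, axes2d = plt.subplots(nrows=1, ncols=4, squeeze=False)
print(axes2d.shape)  # (1, 4)
plt.close(fig)
```

Both loop iterations print an axes shape of (4,), while the GridSpec reports (4, 1) and (1, 4) respectively.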