Is it possible to embed a variable number of plots in a matplotlib axis? For example, the `inset_axes` method is used to place inset axes inside parent axes.
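For reference, here is a minimal sketch of a single inset. It assumes matplotlib 3.0 or newer, where `Axes.inset_axes` is a method on the parent axes. Older code often used the `mpl_toolkits.axes_grid1.inset_locator.inset_axes` helper instead, and which of the two the question refers to is an assumption.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(np.random.random(40))

# Axes.inset_axes takes [x0, y0, width, height] in the parent's axes
# coordinates (0..1), so this inset fills the upper-right part of ax.
inset = ax.inset_axes([0.55, 0.55, 0.4, 0.4])
inset.plot(np.random.random(40))
```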
However, I have several rows of plots and I want to include some inset axes inside the last axis object of each row.
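```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 4, figsize=(15, 15))
for i in range(2):
    ax[i][0].plot(np.random.random(40))
    ax[i][1].plot(np.random.random(40))
    ax[i][2].plot(np.random.random(40))
    # number of inset axes
    number_inset = 5
    # all of these lines currently end up in the same (last) axes
    for j in range(number_inset):
        ax[i][3].plot(np.random.random(40))
plt.show()
```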
Here, instead of the 5 lines drawn in a single axes in the last column, I want several inset axes, each containing one plot. Something like this:

The reason for this is that every row refers to a different item to be plotted, and the last column is supposed to contain the components of that item. Is there a way to do this in matplotlib, or maybe an alternative way to visualize this?
As @hitzg mentioned, the most common way to accomplish something like this is to use `GridSpec`. `GridSpec` creates an imaginary grid object that you can slice to produce subplots. It's an easy way to align fairly complex layouts that you want to follow a regular grid.
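A minimal sketch of that slicing on a small 2 x 2 grid (the axes names and the `Agg` backend are just for illustration): a slice such as `gs[:, 0]` produces one axes spanning several cells, while single indices produce one-cell axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(6, 4))
gs = GridSpec(2, 2)  # an imaginary 2 x 2 grid; nothing is created yet

# The left axes spans both rows of column 0...
ax_tall = fig.add_subplot(gs[:, 0])
# ...while column 1 is split into two single-cell axes.
ax_top = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, 1])

for ax in (ax_tall, ax_top, ax_bottom):
    ax.plot(np.random.random(40))
```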
However, it may not be immediately obvious how to use it in this case. You'll need to create a `GridSpec` with `numrows * numinsets` rows by `numcols` columns and then create the "main" axes by slicing it into blocks of `numinsets` rows.
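In the example below (2 rows, 4 columns, 3 insets), we'd slice by `gs[:3, 0]` to get the upper left "main" axes, `gs[3:, 0]` to get the lower left "main" axes, `gs[:3, 1]` to get the next upper axes, etc. For the insets, each one occupies a single grid row of the last column, `gs[m, -1]`, where grid rows 0-2 hold the insets of the first row of plots and grid rows 3-5 hold those of the second.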
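The index arithmetic from the previous paragraph can be checked without drawing anything. The helper `grid_layout` is hypothetical (not part of matplotlib); it lists, for each logical row, the `GridSpec` row slice of its main axes and the grid rows of its insets.

```python
def grid_layout(numrows, numinsets):
    """Hypothetical helper: for each logical row i, return the grid-row
    slice used by its "main" axes and the grid rows used by its insets."""
    layout = []
    for i in range(numrows):
        # A main axes spans numinsets consecutive grid rows: gs[main, j].
        main = slice(i * numinsets, (i + 1) * numinsets)
        # Each inset takes exactly one of those grid rows: gs[m, -1].
        insets = [i * numinsets + k for k in range(numinsets)]
        layout.append((main, insets))
    return layout


for main, insets in grid_layout(2, 3):
    print(main.start, main.stop, insets)
# prints:
# 0 3 [0, 1, 2]
# 3 6 [3, 4, 5]
```

The first slice is the `gs[:3, ...]` case and the second is `gs[3:, ...]` for a 6-row grid.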
As a complete example:
```python
import numpy as np
import matplotlib.pyplot as plt


def build_axes_with_insets(numrows, numcols, numinsets, **kwargs):
    """
    Makes a *numrows* x *numcols* grid of subplots with *numinsets* subplots
    embedded as "sub-rows" in the last column of each row.

    Returns a figure object and a *numrows* x *numcols* object ndarray where
    all but the last column consists of axes objects, and the last column is a
    *numinsets* length object ndarray of axes objects.
    """
    fig = plt.figure(**kwargs)
    gs = plt.GridSpec(numrows*numinsets, numcols)

    axes = np.empty([numrows, numcols], dtype=object)
    for i in range(numrows):
        # Add "main" axes, each spanning numinsets grid rows...
        for j in range(numcols - 1):
            axes[i, j] = fig.add_subplot(gs[i*numinsets:(i+1)*numinsets, j])

        # Add inset axes, one grid row each. They are collected in their own
        # object array first, because axes[i, -1] starts out as None and
        # cannot be indexed into directly.
        insets = np.empty(numinsets, dtype=object)
        for k in range(numinsets):
            m = k + i * numinsets
            insets[k] = fig.add_subplot(gs[m, -1])
        axes[i, -1] = insets

    return fig, axes


def plot(axes):
    """Recursive plotting function just to put something on each axes."""
    for ax in axes.flat:
        data = np.random.normal(0, 1, 100).cumsum()
        try:
            ax.plot(data)
            ax.set(xticklabels=[], yticklabels=[])
        except AttributeError:
            # ax is the nested ndarray of inset axes, so recurse into it
            plot(ax)


fig, axes = build_axes_with_insets(2, 4, 3, figsize=(12, 6))
plot(axes)
fig.tight_layout()
plt.show()
```
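As an alternative layout technique (a swapped-in approach, not the one used above), a nested grid can be built with `GridSpecFromSubplotSpec`: an outer `numrows` x `numcols` grid whose last cell in each row is subdivided into `numinsets` rows. The function name `build_axes_nested` and its return values are hypothetical; this is a sketch assuming a reasonably recent matplotlib.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec


def build_axes_nested(numrows, numcols, numinsets, **kwargs):
    """Hypothetical alternative: an outer numrows x numcols grid whose last
    cell in each row is split into numinsets stacked sub-rows."""
    fig = plt.figure(**kwargs)
    outer = GridSpec(numrows, numcols)
    main_axes = np.empty((numrows, numcols - 1), dtype=object)
    insets = np.empty((numrows, numinsets), dtype=object)
    for i in range(numrows):
        # One outer cell per "main" axes.
        for j in range(numcols - 1):
            main_axes[i, j] = fig.add_subplot(outer[i, j])
        # Nested grid: split the last outer cell of row i into numinsets rows.
        inner = GridSpecFromSubplotSpec(numinsets, 1, subplot_spec=outer[i, -1])
        for k in range(numinsets):
            insets[i, k] = fig.add_subplot(inner[k, 0])
    return fig, main_axes, insets


fig, main_axes, insets = build_axes_nested(2, 4, 3, figsize=(12, 6))
for ax in list(main_axes.flat) + list(insets.flat):
    ax.plot(np.random.normal(0, 1, 100).cumsum())
    ax.tick_params(labelbottom=False, labelleft=False)  # hide tick labels
fig.tight_layout()
```

Returning two rectangular arrays instead of a nested object array means no recursion is needed when plotting, at the cost of a different return shape.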