I have a subplot of subplots. The outer grid consists of one row by two columns, and each of the two inner grids consists of two rows and two columns. Suppose I want the figure legend to show only the labels that belong to the first 2x2 inner grid. How can I go about doing this? My attempt is below:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# Grid sizes: the outer grid is 1 row x 2 columns, and every outer cell
# holds an inner grid of 2 rows x 2 columns.
outerD = {'nrows': 1, 'ncols': 2}
innerD = {'nrows': 2, 'ncols': 2}
D = {'inner': innerD, 'outer': outerD}
def initialize_dubsub(D, figsize=None):
    """Draw an outer grid whose cells each hold an inner grid of empty axes.

    Every inner axes gets one labelled, empty line named ``'<outer>x<inner>'``.
    ``fig.legend()`` is called with no handles, so it gathers the labelled
    artists from *every* axes in the figure (8 entries for a 1x2 outer grid
    of 2x2 inner grids) -- this is the behaviour the question is about.

    Parameters
    ----------
    D : dict
        Grid sizes, shaped like
        ``{'outer': {'nrows': int, 'ncols': int},
           'inner': {'nrows': int, 'ncols': int}}``.
    figsize : tuple of float, optional
        Forwarded unchanged to ``plt.figure``.
    """
    n_outer = D['outer']['nrows'] * D['outer']['ncols']
    n_inner = D['inner']['nrows'] * D['inner']['ncols']

    fig = plt.figure(figsize=figsize)
    # Equal width/height ratios are the GridSpec default, so none are passed;
    # hard-coded ratio lists would break as soon as the grid size changes.
    outer_grid = gridspec.GridSpec(
        D['outer']['nrows'], D['outer']['ncols'], wspace=0.2, hspace=0.2)
    axes = []
    # One inner grid per *outer* cell: the loop bound must come from the
    # outer grid, since ``outer_grid[n]`` indexes outer cells.
    for n in range(n_outer):
        inner = gridspec.GridSpecFromSubplotSpec(
            D['inner']['nrows'], D['inner']['ncols'],
            subplot_spec=outer_grid[n], wspace=0.25, hspace=0.3)
        for m in range(n_inner):
            ax = fig.add_subplot(inner[m])
            ax.plot([], [], label='{}x{}'.format(n, m))
            ax.set_xticks([])
            ax.set_yticks([])
            axes.append(ax)
    # Attempt at a legend for the first inner grid only. It fails because
    # ``axes[:4]`` is a plain list, and lists have no
    # get_legend_handles_labels() method:
    # handles, labels = axes[:4].get_legend_handles_labels() # first 2x2
    # fig.legend(handles=handles, labels=labels, loc='lower center')
    fig.legend(loc='lower center', ncol=4, mode='expand')
    plt.show()
    plt.close(fig)
initialize_dubsub(D, figsize=None)  # render the demo figure at the default size
This code outputs 8 handles and 8 labels, whereas I want 4 of each. I commented out the `get_legend_handles_labels()` call because that method is defined on a single Axes, not on a list of axes. I realize I can use `ax.legend()`, but I prefer to use `fig.legend(...)`. How can I achieve this?
Rather than trying to call `.get_legend_handles_labels()` on the list of subplots you want, loop over the axes in that list and append the handles and labels from each of those four subplots to a list.
For example:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# Grid sizes: the outer grid is 1 row x 2 columns, and every outer cell
# holds an inner grid of 2 rows x 2 columns.
outerD = {'nrows': 1, 'ncols': 2}
innerD = {'nrows': 2, 'ncols': 2}
D = {'inner': innerD, 'outer': outerD}
def initialize_dubsub(D, figsize=None, legend_panel=0):
    """Draw nested subplot grids and add a figure legend for one inner grid.

    The outer ``GridSpec`` has ``D['outer']`` cells; each outer cell holds an
    inner grid of ``D['inner']`` empty, tick-less axes. Every inner axes gets
    one labelled (empty) line named ``'<outer>x<inner>'``. Only the handles
    and labels of the inner axes belonging to outer cell *legend_panel* are
    passed to ``fig.legend``, so the legend has one entry per inner subplot
    of that single grid.

    Parameters
    ----------
    D : dict
        Grid sizes, shaped like
        ``{'outer': {'nrows': int, 'ncols': int},
           'inner': {'nrows': int, 'ncols': int}}``.
    figsize : tuple of float, optional
        Forwarded unchanged to ``plt.figure``.
    legend_panel : int, optional
        Flat index of the outer cell whose inner subplots supply the legend
        entries. Defaults to 0, the first inner grid.

    Raises
    ------
    ValueError
        If *legend_panel* does not name an existing outer cell.
    """
    n_outer = D['outer']['nrows'] * D['outer']['ncols']
    n_inner = D['inner']['nrows'] * D['inner']['ncols']
    if not 0 <= legend_panel < n_outer:
        raise ValueError(
            'legend_panel must be in [0, {}), got {}'.format(n_outer, legend_panel))

    fig = plt.figure(figsize=figsize)
    # Equal width/height ratios are the GridSpec default, so none are passed;
    # hard-coded ratio lists would break as soon as the grid size changes.
    outer_grid = gridspec.GridSpec(
        D['outer']['nrows'], D['outer']['ncols'], wspace=0.2, hspace=0.2)
    axes = []
    # One inner grid per *outer* cell: the loop bound must come from the
    # outer grid, since ``outer_grid[n]`` indexes outer cells.
    for n in range(n_outer):
        inner = gridspec.GridSpecFromSubplotSpec(
            D['inner']['nrows'], D['inner']['ncols'],
            subplot_spec=outer_grid[n], wspace=0.25, hspace=0.3)
        for m in range(n_inner):
            ax = fig.add_subplot(inner[m])
            ax.plot([], [], label='{}x{}'.format(n, m))
            ax.set_xticks([])
            ax.set_yticks([])
            axes.append(ax)

    # ``axes`` is filled one outer cell at a time, so outer cell k's inner
    # axes are the contiguous slice [k * n_inner, (k + 1) * n_inner).
    start = legend_panel * n_inner
    handles, labels = [], []
    for ax in axes[start:start + n_inner]:
        ax_handles, ax_labels = ax.get_legend_handles_labels()
        handles.extend(ax_handles)
        labels.extend(ax_labels)
    fig.legend(handles=handles, labels=labels, loc='lower center')
    plt.show()
    plt.close(fig)
initialize_dubsub(D, figsize=None)  # render the demo figure at the default size