Is there a way of getting the inset axes by "asking" the axes it is embedded in?


I have several subplots, axs, some of them with embedded inset axes. I would like to get the data plotted in the insets by iterating over the main axes. Let's consider this minimal reproducible example:

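import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(1, 3)
x = np.array([0, 1, 2])
for i, ax in enumerate(axs):
    if i != 1:  # only the first and third subplots get an inset
        ins = ax.inset_axes([.5, .5, .4, .4])
        ins.plot(x, i * x)
plt.show()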

Is there a way of doing something like the following, where has_inset and get_inset are hypothetical methods?

data = []
for ax in axs:
    if ax.has_inset():       # "asking" if ax has embedded inset
        ins = ax.get_inset() # getting the inset from ax
        line = ins.get_lines()[0]
        dat = line.get_xydata()
        data.append(dat)
print(data)
# [array([[0., 0.],
#         [1., 0.],
#         [2., 0.]]),
#  array([[0., 0.],
#         [1., 2.],
#         [2., 4.]])]

Solution

  • Inset axes are registered as children of the axes they are embedded in, so you could use get_children and filter for Axes instances to retrieve them:

    from matplotlib.axes import Axes
    
    def get_insets(ax):
        return [c for c in ax.get_children()
                if isinstance(c, Axes)]
    
    for ax in fig.axes:
        print(get_insets(ax))
    

    Output:

    [<Axes:label='inset_axes'>]
    []
    [<Axes:label='inset_axes'>]
    

    For your particular example:

    data = []
    for ax in fig.axes:
        for ins in get_insets(ax):
            line = ins.get_lines()[0]  # first line plotted in the inset
            dat = line.get_xydata()    # (N, 2) array of x, y pairs
            data.append(dat)
    print(data)
    

    Output:

    [array([[0., 0.],
            [1., 0.],
            [2., 0.]]),
     array([[0., 0.],
            [1., 2.],
            [2., 4.]])]
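
Putting the pieces together, here is a minimal end-to-end sketch of the approach above. It assumes the non-interactive Agg backend so that it runs without a display. The helper name collect_inset_data is hypothetical and not part of matplotlib. Unlike the snippet above, it gathers every line in every inset and records which parent axes each inset belongs to.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run, no window is opened
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def get_insets(ax):
    """Return the Axes objects that are direct children of `ax`."""
    return [c for c in ax.get_children() if isinstance(c, Axes)]


def collect_inset_data(fig):
    """Hypothetical helper: map the index of each parent axes in fig.axes
    to a list of xy-data arrays, one per line found in its insets."""
    result = {}
    for idx, ax in enumerate(fig.axes):
        arrays = [line.get_xydata()
                  for ins in get_insets(ax)
                  for line in ins.get_lines()]
        if arrays:  # skip axes without insets (or with empty insets)
            result[idx] = arrays
    return result


# Rebuild the figure from the question: insets on the first and third axes.
fig, axs = plt.subplots(1, 3)
x = np.array([0, 1, 2])
for i, ax in enumerate(axs):
    if i != 1:
        ins = ax.inset_axes([.5, .5, .4, .4])
        ins.plot(x, i * x)

inset_data = collect_inset_data(fig)
print(sorted(inset_data))  # indices of the axes that carry an inset
```

Keying the result by axes index is only one possible design; a flat list, as in the answer above, is enough when you do not need to know which subplot an inset came from.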