
Nesting already created matplotlib figures into a new one


Is there a way to nest already created figures inside a new figure in matplotlib so that they appear side by side?

Here is how I create my figures. I use the pypianoroll library to plot them; generate_random_multitrack just creates random examples.

import numpy as np
from pypianoroll import Multitrack, Track

def generate_random_multitrack():
    prs = np.random.randint(128, size=(3, 16, 128))
    tracks = []
    for pr in prs:
        tracks.append(
            Track(pianoroll=pr)
        )
    return Multitrack(tracks=tracks)
mt1 = generate_random_multitrack()
mt2 = generate_random_multitrack()

For example, mt1.plot() returns a list of axes and creates this figure: (image not available)

I would like to have something like this, with mt1 on the left and mt2 on the right: (image not available)

I can access the figure and axes created by each .plot() call with:

import matplotlib.pyplot as plt

axes1 = mt1.plot()
fig1 = plt.gcf()
axes2 = mt2.plot()
fig2 = plt.gcf()
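As an alternative to calling plt.gcf() right after each .plot() call, the figure can be looked up from the returned axes, because every matplotlib Axes keeps a reference to the figure it belongs to. The sketch below uses a hypothetical stand-in, plot_three_tracks, in place of mt.plot() so it runs without pypianoroll; it assumes, as the snippet above does, that each call draws into a new figure and returns its axes.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_three_tracks(pianorolls):
    """Hypothetical stand-in for mt.plot(): draws each piano roll in its own
    subplot of a new figure and returns the list of axes."""
    fig, axs = plt.subplots(len(pianorolls), 1, sharex=True)
    for ax, pr in zip(axs, pianorolls):
        ax.imshow(pr.T, aspect="auto", origin="lower")
    return list(axs)


rng = np.random.default_rng(0)
axes1 = plot_three_tracks(rng.integers(128, size=(3, 16, 128)))
axes2 = plot_three_tracks(rng.integers(128, size=(3, 16, 128)))

# Recover each figure from its own axes; unlike plt.gcf(), this does not
# depend on which figure happens to be "current" at the moment of the call.
fig1 = axes1[0].get_figure()
fig2 = axes2[0].get_figure()
```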

Is there a way to nest these two figures side by side in a new matplotlib figure?

import matplotlib.pyplot as plt
fig = plt.figure()
# ... Can I nest fig1 and fig2 side by side inside fig?

Thanks in advance for your help and comments.
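Taken literally, the question asks to combine figures that already exist. Rather than moving axes between figures, one workaround is to render each finished figure to an image in memory and show the two images side by side in a new figure. This is a sketch under assumptions: fig1 and fig2 are built here with plain plt.subplots as stand-ins for the pypianoroll output, the helper figure_to_array is hypothetical, and the result is a static raster picture (no per-track zooming, resolution fixed by dpi).

```python
import io

import numpy as np
import matplotlib.pyplot as plt


def figure_to_array(fig, dpi=100):
    """Hypothetical helper: render an existing figure to an RGBA array
    by saving it as a PNG into an in-memory buffer and reading it back."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    return plt.imread(buf, format="png")


# Stand-ins for fig1 and fig2 created earlier by mt1.plot() / mt2.plot().
fig1, axs1 = plt.subplots(3, 1, sharex=True)
fig2, axs2 = plt.subplots(3, 1, sharex=True)
for ax in [*axs1, *axs2]:
    ax.plot(np.random.rand(16))

img1 = figure_to_array(fig1)
img2 = figure_to_array(fig2)

# New figure with two panels, each showing one rendered figure.
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 4))
for ax, img in ((left, img1), (right, img2)):
    ax.imshow(img)
    ax.set_axis_off()  # hide ticks and frame of the container axes
```

Because the panels only hold pictures, this is mainly useful for exporting a combined image; drawing into a shared layout from the start keeps the tracks as real, editable axes.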


Solution

  • According to its documentation, pypianoroll's plot method accepts an axs argument (the axes to draw into). So you can define your plot layout first and then fill the subplots with the tracks, e.g.,

    import numpy as np
    from pypianoroll import Multitrack, Track
    import matplotlib.pyplot as plt
    
    def generate_random_multitrack():
        prs = np.random.randint(128, size=(3, 16, 128))
        tracks = []
        for pr in prs:
            tracks.append(
                Track(pianoroll=pr)
            )
        return Multitrack(tracks=tracks)
    
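    # 3 rows (one per track) x 2 columns (one per multitrack)
    fig, axes = plt.subplots(3, 2, sharex=True)
    
    mt1 = generate_random_multitrack()
    mt2 = generate_random_multitrack()
    
    # axes.flat walks the grid row by row, so the even positions form the
    # left column (same as axes[:, 0]) and the odd ones the right column
    mt1.plot(axs=axes.flat[::2])
    mt2.plot(axs=axes.flat[1::2])
    plt.tight_layout()
    fig.subplots_adjust(hspace=0)  # stack the tracks without vertical gaps
    plt.show()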
    

    Sample output: (image not available)
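If each multitrack should get its own container inside the new figure (for example, to give each side its own title), matplotlib 3.4+ provides Figure.subfigures, which splits one figure into nested subfigures that each have their own subplots. Below is a sketch assuming matplotlib >= 3.5 (for layout="constrained"); plot_tracks_into is a hypothetical stand-in for mt.plot(axs=...) so the example runs without pypianoroll.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_tracks_into(axs, pianorolls):
    """Hypothetical stand-in for mt.plot(axs=...): one piano roll per axes."""
    for ax, pr in zip(axs, pianorolls):
        ax.imshow(pr.T, aspect="auto", origin="lower")
    return list(axs)


rng = np.random.default_rng(0)
rolls1 = rng.integers(128, size=(3, 16, 128))
rolls2 = rng.integers(128, size=(3, 16, 128))

fig = plt.figure(layout="constrained", figsize=(10, 6))
left, right = fig.subfigures(1, 2)  # two side-by-side subfigures

# Each subfigure gets its own column of 3 axes sharing the time axis.
axs_left = left.subplots(3, 1, sharex=True)
axs_right = right.subplots(3, 1, sharex=True)

plot_tracks_into(axs_left, rolls1)   # with pypianoroll: mt1.plot(axs=axs_left)
plot_tracks_into(axs_right, rolls2)  # with pypianoroll: mt2.plot(axs=axs_right)
left.suptitle("mt1")
right.suptitle("mt2")
```

The design is the same as in the answer (create the axes first, then let each plot call fill them); subfigures only add a per-side grouping, and constrained layout is used here in place of tight_layout to handle the spacing.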