python, r, matplotlib, r-grid

Arrange plots with subplots, created by separate functions, on a grid in matplotlib


I am looking for something similar to arrangeGrob in R (from the gridExtra package).

I have a function (say, FUN1) that creates a plot with subplots. The number of subplots FUN1 creates may vary, and the plot itself is quite complex. I have two other functions, FUN2 and FUN3, which also create plots of varying structure.

Is there a simple way to define an overall GRID, for example a 3-row, 1-column layout, and then simply pass

FUN1 --> GRID(row 1, col 1)
FUN2 --> GRID(row 2, col 1)
FUN3 --> GRID(row 3, col 1)

afterwards, so that the complicated plot generated by FUN1 is drawn in row 1, the plot generated by FUN2 in row 2, and so on, without specifying the subplot layout inside the FUNs beforehand?


Solution

  • The usual way to create plots with matplotlib is to create the axes first and then plot into those axes. The axes can be laid out on a grid using plt.subplots, figure.add_subplot, plt.subplot2grid or, for more sophisticated layouts, GridSpec.

    Once those axes are created, they can be passed to functions that plot content into them. The following example creates six axes and uses three different functions to plot into them.

    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import numpy as np
    
    def func1(ax, bx, cx):
        x = np.arange(3)
        x2 = np.linspace(-3,3)
        y1 = [1,2,4]
        y2 = [3,2.5,3.4]
        f = lambda x: np.exp(-x**2)
        ax.bar(x-0.5, y1, width=0.4)
        ax.bar(x, y2, width=0.4)
        bx.plot(x,y1, label="lab1")
        bx.scatter(x,y2, label="lab2")
        bx.legend()
        cx.fill_between(x2, f(x2))
    
    def func2(ax, bx):
        x = np.arange(1,18)/1.9
        y = np.arange(1,6)/1.4
        z = np.outer(np.sin(x), -np.sqrt(y)).T
        ax.imshow(z, aspect="auto", cmap="Purples_r")
        X, Y = np.meshgrid(np.linspace(-3,3),np.linspace(-3,3))
        U = -1-X**2+Y
        V = 1+X-Y**2
        bx.streamplot(X, Y, U, V, color=U, linewidth=2, cmap="autumn")
    
    def func3(ax):
        data = [sorted(np.random.normal(0, s, 100)) for s in range(2,5)]
        ax.violinplot(data)
    
    
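    # 3 x 4 outer grid with unequal column widths and row heights
    gs = gridspec.GridSpec(3, 4,
                           width_ratios=[1, 1.5, 0.75, 1], height_ratios=[3, 2, 2])
    
    # Slicing the GridSpec lets a single axes span several grid cells
    ax1 = plt.subplot(gs[0:2,0])
    ax2 = plt.subplot(gs[2,0:2])
    ax3 = plt.subplot(gs[0,1:3])
    ax4 = plt.subplot(gs[1,1])
    ax5 = plt.subplot(gs[0,3])
    ax6 = plt.subplot(gs[1:,2:])
    
    # Hand each function the axes it should draw into
    func1(ax1, ax3, ax5)
    func3(ax2)
    func2(ax4, ax6)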
    
    plt.tight_layout()
    plt.show()
    

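With the approach above, the caller decides in advance which axes each function receives. If each function should instead create its own subplots, closer to what arrangeGrob does, one option is Figure.subfigures, which as far as I know was introduced (as a provisional API) in Matplotlib 3.4, so check your installed version. The figure is split into a 3 x 1 grid of sub-figures, and each function receives one sub-figure and calls subplots on it with whatever layout it needs. This is a minimal sketch: fun1, fun2, fun3 and arrange are hypothetical stand-ins for FUN1 to FUN3 from the question, not Matplotlib functions, and the data is placeholder data.

```python
import matplotlib.pyplot as plt
import numpy as np


def fun1(subfig):
    # Hypothetical stand-in for FUN1: it alone decides on a 1 x 3 row of axes.
    axes = subfig.subplots(1, 3)
    x = np.linspace(0, 2 * np.pi, 100)
    for k, ax in enumerate(axes, start=1):
        ax.plot(x, np.sin(k * x))
        ax.set_title(f"sin({k}x)")
    subfig.suptitle("FUN1")
    return axes


def fun2(subfig):
    # Hypothetical stand-in for FUN2: an image next to a histogram (1 x 2).
    ax_img, ax_hist = subfig.subplots(1, 2)
    data = np.random.default_rng(0).normal(size=(20, 20))
    ax_img.imshow(data, aspect="auto", cmap="Purples_r")
    ax_hist.hist(data.ravel(), bins=20)
    subfig.suptitle("FUN2")
    return [ax_img, ax_hist]


def fun3(subfig):
    # Hypothetical stand-in for FUN3: a single axes with violin plots.
    ax = subfig.subplots()
    rng = np.random.default_rng(1)
    ax.violinplot([rng.normal(0, s, 100) for s in range(2, 5)])
    subfig.suptitle("FUN3")
    return [ax]


def arrange(funcs, figsize=(9, 9)):
    # Hypothetical helper: a len(funcs) x 1 grid of sub-figures, one per function,
    # roughly like arrangeGrob(..., ncol = 1) in R.
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    # squeeze=False always gives a 2-D array, so this also works for one function
    subfigs = fig.subfigures(len(funcs), 1, squeeze=False)[:, 0]
    for subfig, func in zip(subfigs, funcs):
        func(subfig)  # each function builds its own subplots inside its row
    return fig


if __name__ == "__main__":
    arrange([fun1, fun2, fun3])
    plt.show()
```

Each function only needs the sub-figure it is handed, so the functions themselves do not change when the overall grid does.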
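Another way to let each function choose its own sub-layout builds on the GridSpec machinery mentioned above and does not rely on sub-figures: create an outer 3 x 1 GridSpec, pass each function one cell (a SubplotSpec), and let the function subdivide that cell with gridspec.GridSpecFromSubplotSpec into however many axes it needs. This is a minimal sketch: fun1, fun2, fun3 and build are hypothetical names standing in for the FUNs from the question, and the plotted data is placeholder data.

```python
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np


def fun1(fig, cell, n_panels=4):
    # Hypothetical FUN1: splits the cell it is given into a 1 x n_panels row.
    inner = gridspec.GridSpecFromSubplotSpec(1, n_panels, subplot_spec=cell, wspace=0.3)
    x = np.linspace(0, 1, 50)
    axes = [fig.add_subplot(inner[0, j]) for j in range(n_panels)]
    for j, ax in enumerate(axes):
        ax.plot(x, x ** (j + 1))
    return axes


def fun2(fig, cell):
    # Hypothetical FUN2: a 2 x 2 block of scatter plots inside its cell.
    inner = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=cell)
    rng = np.random.default_rng(0)
    axes = [fig.add_subplot(inner[i, j]) for i in range(2) for j in range(2)]
    for ax in axes:
        ax.scatter(rng.random(20), rng.random(20))
    return axes


def fun3(fig, cell):
    # Hypothetical FUN3: uses the whole cell as a single axes.
    ax = fig.add_subplot(cell)
    ax.bar(range(3), [1, 2, 4])
    return [ax]


def build(funcs):
    fig = plt.figure(figsize=(8, 9))
    # The overall "GRID" from the question: len(funcs) rows, 1 column.
    outer = gridspec.GridSpec(len(funcs), 1, figure=fig, hspace=0.4)
    for row, func in enumerate(funcs):
        func(fig, outer[row, 0])  # pass each function its cell (a SubplotSpec)
    return fig


if __name__ == "__main__":
    build([fun1, fun2, fun3])
    plt.show()
```

Because the functions only see the cell they are given, rearranging the outer grid (for example into 1 row and 3 columns) only means editing build, not FUN1 to FUN3.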