I am looking for something similar to arrangeGrob in R:
I have a function (say, FUN1) that creates a plot with subplots. The number of subplots FUN1 creates may vary, and the plot itself is quite complex. I have two other functions, FUN2 and FUN3, which also create plots of varying structure.
Is there a simple way to define/arrange an overall GRID, for example a 3 rows by 1 column layout, and simply pass
FUN1 --> GRID(row 1, col 1)
FUN2 --> GRID(row 2, col 1)
FUN3 --> GRID(row 3, col 1)
afterwards, so that the complicated plot generated by FUN1 gets plotted in row 1, the plot generated by FUN2 in row 2, and so on, without specifying the subplot layout inside the FUNs beforehand?
The usual way to create plots with matplotlib is to create the axes first and then plot to those axes. The axes can be set up on a grid using plt.subplots, figure.add_subplot, plt.subplot2grid or, for more sophisticated layouts, GridSpec.
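As a minimal sketch, here are the four options applied to the 3 rows by 1 column grid from the question (assuming matplotlib 3.x). The Agg backend line only lets the snippet run without a display, and fun_line, fun_bar and fun_scatter are hypothetical stand-ins for FUN1 to FUN3 that each draw into a single axes:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# 1) plt.subplots: a figure and a 3x1 array of axes in one call
fig1, axes = plt.subplots(3, 1)

# 2) Figure.add_subplot(nrows, ncols, index); index counts from 1
fig2 = plt.figure()
axes2 = [fig2.add_subplot(3, 1, i) for i in range(1, 4)]

# 3) plt.subplot2grid(shape, loc): loc is the (row, col) of the cell
fig3 = plt.figure()
axes3 = [plt.subplot2grid((3, 1), (i, 0), fig=fig3) for i in range(3)]

# 4) GridSpec: an explicit grid object, indexed like an array
fig4 = plt.figure()
gs = gridspec.GridSpec(3, 1, figure=fig4)
axes4 = [fig4.add_subplot(gs[i, 0]) for i in range(3)]

# Hypothetical plotting functions: they only draw, they never create axes
def fun_line(ax):
    ax.plot([0, 1, 2], [1, 3, 2])

def fun_bar(ax):
    ax.bar([0, 1, 2], [2, 1, 3])

def fun_scatter(ax):
    ax.scatter([0, 1, 2], [3, 2, 1])

# Row i of the grid receives the i-th function's plot
for fun, ax in zip((fun_line, fun_bar, fun_scatter), axes):
    fun(ax)
```

All four variants place the axes in the same cells; which one to use is mostly a matter of taste, with GridSpec being the most flexible when cells need to span several rows or columns.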
Once those axes are created, they can be passed to functions that plot content to them. The following example creates 6 axes and uses 3 different functions to plot to them.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Each function receives ready-made axes and only draws into them;
# where those axes sit in the figure is decided outside the function.
def func1(ax, bx, cx):
    x = np.arange(3)
    x2 = np.linspace(-3,3)
    y1 = [1,2,4]
    y2 = [3,2.5,3.4]
    f = lambda x: np.exp(-x**2)
    # two sets of bars side by side in ax
    ax.bar(x-0.5, y1, width=0.4)
    ax.bar(x, y2, width=0.4)
    # line and scatter with a legend in bx
    bx.plot(x,y1, label="lab1")
    bx.scatter(x,y2, label="lab2")
    bx.legend()
    # filled Gaussian curve in cx
    cx.fill_between(x2, f(x2))

def func2(ax, bx):
    # image of an outer product in ax
    x = np.arange(1,18)/1.9
    y = np.arange(1,6)/1.4
    z = np.outer(np.sin(x), -np.sqrt(y)).T
    ax.imshow(z, aspect="auto", cmap="Purples_r")
    # streamlines of a 2D vector field in bx
    X, Y = np.meshgrid(np.linspace(-3,3),np.linspace(-3,3))
    U = -1-X**2+Y
    V = 1+X-Y**2
    bx.streamplot(X, Y, U, V, color=U, linewidth=2, cmap="autumn")

def func3(ax):
    # violin plot of three normal samples with increasing spread
    data = [sorted(np.random.normal(0, s, 100)) for s in range(2,5)]
    ax.violinplot(data)

# 3x4 grid with unequal column widths and row heights
gs = gridspec.GridSpec(3, 4,
                       width_ratios=[1,1.5,0.75,1], height_ratios=[3,2,2] )

# slicing the GridSpec lets an axes span several cells
ax1 = plt.subplot(gs[0:2,0])
ax2 = plt.subplot(gs[2,0:2])
ax3 = plt.subplot(gs[0,1:3])
ax4 = plt.subplot(gs[1,1])
ax5 = plt.subplot(gs[0,3])
ax6 = plt.subplot(gs[1:,2:])

# hand the axes to the plotting functions
func1(ax1, ax3, ax5)
func3(ax2)
func2(ax4, ax6)

plt.tight_layout()
plt.show()
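The example above still has the caller create every axes. For the case in the question, where FUN1 decides for itself how many subplots it needs, one possible sketch is to give each function one cell of an outer GridSpec (a SubplotSpec) and let it subdivide that cell with GridSpecFromSubplotSpec. Here fun1, fun2 and fun3 are hypothetical stand-ins, and the (fig, spec) calling convention is an assumption of this sketch, not a matplotlib API:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

def fun1(fig, spec):
    """Hypothetical FUN1: fills its cell with a 1x3 row of line plots."""
    inner = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=spec)
    axes = [fig.add_subplot(inner[0, j]) for j in range(3)]
    x = np.linspace(0, 2 * np.pi, 50)
    for j, ax in enumerate(axes):
        ax.plot(x, np.sin((j + 1) * x))
    return axes

def fun2(fig, spec):
    """Hypothetical FUN2: fills its cell with a 2x2 block of scatter plots."""
    inner = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=spec)
    axes = [fig.add_subplot(inner[i, j]) for i in range(2) for j in range(2)]
    rng = np.random.default_rng(0)
    for ax in axes:
        ax.scatter(rng.normal(size=20), rng.normal(size=20))
    return axes

def fun3(fig, spec):
    """Hypothetical FUN3: uses its whole cell for a single bar chart."""
    ax = fig.add_subplot(spec)
    ax.bar([0, 1, 2], [3, 1, 2])
    return [ax]

fig = plt.figure(figsize=(8, 9))
# The overall GRID: 3 rows x 1 column; only this line fixes the arrangement
outer = gridspec.GridSpec(3, 1, figure=fig)

# FUN1 --> GRID(row 1), FUN2 --> GRID(row 2), FUN3 --> GRID(row 3)
panels = [fun1(fig, outer[0, 0]), fun2(fig, outer[1, 0]), fun3(fig, outer[2, 0])]
```

Because the functions only see the cell they are given, changing the outer grid to, for example, GridSpec(1, 3) and passing outer[0, 0], outer[0, 1] and outer[0, 2] would place the three blocks side by side without touching the functions. In matplotlib 3.4 and later, Figure.subfigures supports a similar pattern in which each function receives its own SubFigure and calls its subplots method.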