Subplots not taking arguments


I'm trying to make a 3x9 grid of subplots with contourf from 3 lists, but when I plot it, the plot options (ylim, axis scaled, no axis ticks) are applied only to the last plot. How can I apply them to all the plots instead? Furthermore, all the plots end up on top of each other, as illustrated in the picture below. How can I space them properly?

fig, axes = plt.subplots(3, 9)
for j in range(len(axes[0])):
    
    levels = np.linspace(v_min, v_max, 21)
    for i in range(1, 19, 2):
        axes[0][j].contourf(V_avg[i], levels=levels, cmap=rgb_V)
    np.linspace(v_min, v_max, 11)
    
    for i in range(2, 20, 2):
        axes[1][j].contourf(V_avg[i], levels=levels, cmap=rgb_V)
    np.linspace(v_min, v_max, 11)
    
    levels = np.linspace(v_min_d_avg, v_max_d_avg, 21)
    for i in range(0, 9):
        axes[2][j].contourf(V_avg_dud[i], levels=levels, cmap=rgb_D_V)
    np.linspace(v_min_d_avg, v_max_d_avg, 11)

plt.xticks([])
plt.yticks([])
plt.axis('scaled')
plt.ylim([15, 90])
plt.savefig("aaa", dpi=300, bbox_inches='tight')
plt.show()

Thanks in advance.

[Image: the resulting figure, with overlapping subplots]


Solution

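  • You can use axes.flatten() to loop through the subplots. Calls such as plt.xticks() and plt.ylim() act only on the current Axes, which after plt.subplots() is the last one created, so per-plot settings have to be applied to each Axes object inside the loop.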

    Here is a simple example:

    import numpy as np
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(3, 9, sharex=True, sharey=True)
    for ax in axes.flatten():
        x = np.random.randint(1, high=10, size=10)
        y = np.random.randint(1, high=10, size=10)
        ax.scatter(x, y)
    
    plt.show()
    

    This produces a 3x9 grid of scatter plots that share their x and y axes.

    You don't give an example data set format, but here is an example where your data are in a dictionary:

    # Build example dataset
    keys = ["A", "B", "C", "D"]
    data_dict = {}
    for key in keys:
        data_dict[key] = [
            np.random.randint(1, high=10, size=10).tolist(),
            np.random.randint(1, high=10, size=10).tolist(),
        ]
    
    # Loop through dictionary for plotting
    fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
    axes = axes.flatten()
    for index, (key, value) in enumerate(data_dict.items()):
        x, y = value[0], value[1]
        axes[index].scatter(x, y)
        axes[index].set_title(key)
    
    plt.show()
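
Applied to the question's contourf case, the same loop can also carry the per-plot options. The following is only a sketch under assumptions: the `fields` list is synthetic stand-in data (the real V_avg format is not shown in the question), the "viridis" colormap replaces the custom rgb_V, and it assumes one field per panel is intended, whereas the question's inner loops draw nine contourf calls onto each Axes.

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic stand-in for the question's data (hypothetical): 27 smooth 2D fields.
y, x = np.mgrid[0:100, 0:40]
fields = [np.sin(x / 6.0 + k) * np.cos(y / 15.0) for k in range(27)]
levels = np.linspace(-1.0, 1.0, 21)

fig, axes = plt.subplots(3, 9, figsize=(18, 8))
for ax, field in zip(axes.flatten(), fields):
    # One field per Axes, so nothing is drawn on top of anything else.
    ax.contourf(field, levels=levels, cmap="viridis")
    # Axes-level equivalents of plt.axis, plt.ylim, plt.xticks, plt.yticks.
    ax.axis("scaled")
    ax.set_ylim(15, 90)
    ax.set_xticks([])
    ax.set_yticks([])
```

Call plt.show() or fig.savefig(...) afterwards, as in the question. The pyplot functions are replaced by their Axes methods because the pyplot versions only touch the current Axes.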
    

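
For the spacing part of the question, one option is to give the figure enough room and let matplotlib arrange the panels. This is a minimal sketch, not the only way to do it: the figsize value is an arbitrary choice for illustration, and constrained_layout=True could be swapped for a call to fig.tight_layout() after plotting.

```python
import matplotlib.pyplot as plt

# figsize is in inches; 18 x 6 leaves roughly 2 x 2 inches per panel.
# constrained_layout=True lets matplotlib space the panels so that
# titles and tick labels do not collide with neighbouring subplots.
fig, axes = plt.subplots(3, 9, figsize=(18, 6), constrained_layout=True)
for index, ax in enumerate(axes.flatten()):
    ax.set_title(f"panel {index}")
```

With the default figure size, 27 panels are squeezed into a small canvas, which is a common reason for titles and labels running into each other.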