python pandas dataframe matplotlib subplot

add grid to plt.subplots


I want to add a grid to each subplot created with plt.subplots, like this: [image]

This is my code (data is a dict mapping strings to pandas DataFrames):

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
for row in ((ax1, ax2), (ax3, ax4)):
    for axes in row:
        data['DPTP'].plot(ax=axes, x="time", y="fitness",
                          label="DPTP(r = 30)", linestyle='dashed', grid=True)

for i in range(2):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax2, x="time", y="fitness", label=lb, grid=True)
ax1.set_title("a")
ax1.grid(True)

for i in range(2, 4):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax3, x="time", y="fitness", label=lb, grid=True)
ax2.set_title("b")
ax2.grid(True)

for i in range(4, 6):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax4, x="time", y="fitness", label=lb, grid=True)
ax3.set_title("c")
ax3.grid(True)

for i in range(6, 10):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax4, x="time", y="fitness", label=lb, grid=True)
ax4.set_title("d")
ax4.grid(True)

plt.show()

In each for-loop I pass grid=True, so I expected every subplot to have a grid, but my plot only has a grid on the first subplot: [image]

Even ax.set_title() does not work for the subplots. Why?


Solution

  • You're somewhat new, so I'll be kind and provide a solution, but you should have provided mock data to work on; that's your job when asking a question here.

    So here is the part you should have provided, the construction of data:

    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    
    # Mock data
    time = np.linspace(0.3, 2, 600)
    fitness = lambda i: 1 / (time + 0.05 * (i + 1)) - 0.1 * (i + 1)
    dataframe = lambda i: pd.DataFrame({"time": time, "fitness": fitness(i)})
    
    data = {"DPTP": dataframe(0)} | {str(i): dataframe(i + 1) for i in range(10)}
    

    Now let's create the axes. I refactored a bit using dicts so that the plotting part is shorter and independent:

    # Prepare axes and axes metadata
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
    
    ax_titles = {ax1: "a", ax2: "b", ax3: "c", ax4: "d"}
    key_to_ax = (
        dict.fromkeys(["0", "1"], ax1)
        | dict.fromkeys(["2", "3"], ax2)
        | dict.fromkeys(["4", "5"], ax3)
        | dict.fromkeys(["6", "7", "8", "9"], ax4)
    )
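
    The dict.fromkeys and | combination above may be easier to follow on a toy example. The sketch below uses placeholder strings instead of real Axes objects (the string values are stand-ins for illustration only) and assumes Python 3.9 or later for the | operator. The unpacking form at the end is an equivalent merge that should also work on older versions:

```python
# Stand-in strings replace real matplotlib Axes objects (illustration only).
ax1, ax2 = "ax1", "ax2"

# dict.fromkeys(keys, value) maps every key to the same value.
first = dict.fromkeys(["0", "1"], ax1)
second = dict.fromkeys(["2", "3"], ax2)

# The | operator (Python 3.9+) merges two dicts into a new one.
key_to_ax = first | second
print(key_to_ax)

# Equivalent merge using unpacking, for Python versions before 3.9.
key_to_ax_legacy = {**first, **second}
```

    Each data key then resolves to the axes it should be drawn on with a plain lookup such as key_to_ax["2"].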
    

    And finally the plotting part. Here you have two choices, either:

    • set grid=True as an argument to all the df.plot calls, or
    • call ax.grid() once you're done plotting new lines on that ax.

    The code below uses the second option:

    # Plot
    for ax in (ax1, ax2, ax3, ax4):
        data["DPTP"].plot(
            ax=ax, x="time", y="fitness", label="DPTP(r = 30)", linestyle="dashed"
        )
    
    for i in range(10):
        data[str(i)].plot(ax=key_to_ax[str(i)], x="time", y="fitness", label=f"Preset {i}")
    
    for ax in (ax1, ax2, ax3, ax4):
        ax.set_title(ax_titles[ax])
        ax.grid()
    
    plt.show()
    

    [image: subplots with grids]
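
    One detail worth noting about the second option: according to the matplotlib documentation, ax.grid() called with no arguments toggles the grid rather than unconditionally enabling it, so ax.grid(True) is the more explicit form. The following is a minimal sketch of that variant, assuming small made-up data and the non-interactive Agg backend so it can run without a display. The data values and the grid_on check are illustrative and not part of the original answer:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up data for illustration only
time = np.linspace(0.3, 2, 50)
df = pd.DataFrame({"time": time, "fitness": 1 / time})

fig, axes = plt.subplots(2, 2)

# axes is a 2x2 NumPy array of Axes; .flat iterates over all four of them
for ax, title in zip(axes.flat, "abcd"):
    df.plot(ax=ax, x="time", y="fitness", label="fitness")
    ax.set_title(title)
    ax.grid(True)  # explicit on, independent of the grid's previous state

# Inspect the first x grid line of each subplot to see whether it is visible
grid_on = [ax.xaxis.get_gridlines()[0].get_visible() for ax in axes.flat]
print(grid_on)
```

    If grid=True is also passed to df.plot, it is probably safer not to follow it with a bare ax.grid(), since the toggle could switch the grid back off.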