python, matplotlib, subplot, figure

Adjusting gridspec so that plotted data aligns


I have several graphs to plot, each with a width that is a multiple of some unit, arranged as a staircase.

The bottom axes span 1/4 of the full width, the second-to-bottom axes span 2/4 of it, and so on.

The code I am using:

import matplotlib.pyplot as plt
divs = 4
fig = plt.figure()
gs = fig.add_gridspec(ncols = divs, nrows = divs)
axes = [fig.add_subplot(gs[div, div:]) for div in range(divs)]
for row in range(divs):
    axes[row].plot([1]*10*(divs - row), c = 'r')
    axes[row].set_xlabel('', fontsize = 6)
fig.set_figheight(10)
fig.set_figwidth(10)
plt.show()

My problem is that the plots don't align exactly as I want. The plot on row 2 begins slightly to the right of the '10' tick mark on row 1, and the same happens for row 3 versus row 2, and so on. I would like the start of the plot on row 2 to line up precisely with the '10' on row 1, and likewise for the other plots. How can I achieve this (preferably, but not necessarily, using gridspec)?

I tried adding axes[row].tick_params(axis="y", direction="in", pad=-22) to push the y-axis tick labels inside the plot, but that didn't change the alignment. I also tried fig.tight_layout(pad=0.3), and that did not change the alignment either.
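To measure the mismatch instead of judging it by eye, you can convert both positions to display (pixel) coordinates with each axes' transData. The sketch below rebuilds the layout from the question and reports, for each row, how far the start of its line (x=0) sits from x=10 on the row above. build_staircase and start_offset_px are hypothetical helper names, and the Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch only
import matplotlib.pyplot as plt


def build_staircase(divs=4):
    """Recreate the staircase layout from the question."""
    fig = plt.figure(figsize=(10, 10))
    gs = fig.add_gridspec(ncols=divs, nrows=divs)
    axes = [fig.add_subplot(gs[div, div:]) for div in range(divs)]
    for row in range(divs):
        axes[row].plot([1] * 10 * (divs - row), c='r')
    return fig, axes


def start_offset_px(fig, axes, row):
    """Hypothetical helper: horizontal pixel distance between the start of
    the line on axes[row] (x=0) and x=10 on axes[row - 1].
    A positive value means the lower line starts to the right of the tick."""
    fig.canvas.draw()  # finalize autoscaled data limits before transforming
    tick_x = axes[row - 1].transData.transform((10, 1))[0]
    start_x = axes[row].transData.transform((0, 1))[0]
    return start_x - tick_x


fig, axes = build_staircase()
for row in range(1, len(axes)):
    print(row, round(float(start_offset_px(fig, axes, row)), 1))
```

A positive number for every row matches the symptom described above, where each lower plot starts slightly to the right of the tick.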


Solution

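  • Matplotlib pads each axis's data limits by a default margin (rcParams['axes.xmargin'] is 0.05, i.e. 5% of the data range on each side). Each line therefore starts a little inside its axes, and the offset differs between axes of different widths. If you set this default x-axis margin to 0, the lines start right at the axes' left edges and the ticks line up.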

    import matplotlib.pyplot as plt

    # Remove the default 5% x padding. Set this before the axes are
    # created, because each new axes reads its margin from rcParams.
    plt.rcParams['axes.xmargin'] = 0.0

    divs = 4
    fig = plt.figure(figsize=(10, 10))
    gs = fig.add_gridspec(ncols=divs, nrows=divs)

    # Row `div` spans grid columns div..divs-1, giving the staircase layout
    axes = [fig.add_subplot(gs[div, div:]) for div in range(divs)]
    for row in range(divs):
        axes[row].plot([1] * 10 * (divs - row), c='r')

    plt.show()
    

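Setting the margin to 0 removes most of the offset, but not necessarily all of it. The gridspec still leaves a gap of wspace (0.2 of a column width by default) between columns. The top line's 40 points also span only x=0..39, so x=10 falls at 10/39 of the top axes rather than exactly on the second column's left edge. Here is a minimal sketch that makes the alignment exact by construction, assuming each grid column should represent exactly 10 x-units. The unit variable is an assumption read off the question's data. The sketch removes the column gap with wspace=0 and pins each axes' x limits with set_xlim.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line and call plt.show() interactively
import matplotlib.pyplot as plt

divs = 4
unit = 10  # assumption: x-units represented by one grid column

fig = plt.figure(figsize=(10, 10))
# wspace=0: no gap between grid columns, so every column boundary sits at a
# whole multiple of the column width
gs = fig.add_gridspec(ncols=divs, nrows=divs, wspace=0)
axes = [fig.add_subplot(gs[div, div:]) for div in range(divs)]

for row, ax in enumerate(axes):
    n_cols = divs - row  # number of grid columns this axes spans
    ax.plot([1] * unit * n_cols, c='r')
    # Exactly `unit` x-units per column, so x=10 on the row above falls on
    # the column boundary where this row's axes begins
    ax.set_xlim(0, unit * n_cols)
```

Because set_xlim fixes the limits explicitly, the axes.xmargin setting is not needed here. The trade-off is that each line ends one unit short of its axes' right edge, since N points cover x=0..N-1.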

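  • Alternatively, create full-width subplots with plt.subplots(gridspec_kw=...) and share the x-axis between them (sharex='col'), so every axes has the same width and the same x limits: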

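    import matplotlib.pyplot as plt
    divs = 4

    # One full-width axes per row; sharex='col' gives all of them the same
    # x limits, so equal x values sit at the same horizontal position.
    # (Equal height_ratios are the default; they are listed here only to
    # show where gridspec options go.)
    fig, axes = plt.subplots(divs, 1, gridspec_kw=dict(height_ratios=[1] * divs),
                             sharex='col', figsize=(10, 10))
    for row in range(divs):
        axes[row].plot([1] * 10 * (divs - row), c='r')

    plt.show()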
    

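With full-width shared axes, every line starts at the left edge (x=0) and the shorter lines simply stop early. You may want the question's original look instead, where each row begins 10 units later than the row above. One option is to shift the x values rather than the axes. The sketch below works under that assumption. The unit variable and the use of NumPy are additions for illustration, not part of the answer above.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line and call plt.show() interactively
import matplotlib.pyplot as plt
import numpy as np

divs = 4
unit = 10  # assumption: each row starts this many x-units after the row above

fig, axes = plt.subplots(divs, 1, sharex='col', figsize=(10, 10))
for row, ax in enumerate(axes):
    n = unit * (divs - row)                 # same point counts as the question
    x = np.arange(unit * row, unit * divs)  # n values starting at x = unit*row
    ax.plot(x, np.ones(n), c='r')
```

All rows share one set of x limits and have the same width, so x=10*row lands at the same horizontal position on every row and the alignment is exact. The cost is that lower rows show empty space on the left instead of being narrower axes.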