Search code examples
python matplotlib matplotlib-gridspec

How can I plot a 2D image with its projections aligned to the image axes, while keeping the projection plots small compared to the image?


I am struggling to find a way to keep the projections of the image perfectly aligned with the image (as in the figure below) while reducing their size, so that the image takes up most of the figure.

    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import matplotlib.image as mpimg
    import numpy as np
    from skimage import data

    # Sample 2-D image and its height/width ratio.
    img = data.coins()
    n_rows, n_cols = img.shape
    aspect = n_rows / n_cols

    fig = plt.figure(figsize=(8, 8))
    # 2x2 grid: [0, 1] = top profile, [1, 0] = left profile, [1, 1] = image.
    # With these ratios the top row is 1/aspect times as tall as the bottom
    # row, and the left column is aspect times as wide as the right column.
    grid = gridspec.GridSpec(
        2,
        2,
        width_ratios=[aspect, 1],
        height_ratios=[1 / aspect, 1],
    )

    image_ax = fig.add_subplot(grid[1, 1])
    image_ax.imshow(img)

    # Mean pixel value of each row, drawn against the row index and negated
    # so that the curve extends towards the left.
    row_profile = img.mean(axis=1)
    left_ax = fig.add_subplot(grid[1, 0])
    left_ax.set_title('Left Plot')
    left_ax.plot(-row_profile, range(n_rows))

    # Mean pixel value of each column, drawn against the column index.
    col_profile = img.mean(axis=0)
    top_ax = fig.add_subplot(grid[0, 1])
    top_ax.plot(col_profile)
    top_ax.set_title('Top Plot')

    plt.tight_layout()
    plt.show()

enter image description here

Basically, I would like the top plot to have a smaller height and the left plot to have a smaller width, while keeping both perfectly aligned with the image.


Solution

  • You could do the following. Setting aspect="auto" makes the image fill its grid cell so that it lines up with the side plots, but it can also distort the image, so in the example below I've chosen the figure size to compensate for that:

    from matplotlib import pyplot as plt
    from skimage import data
    import numpy as np


    # Flip the rows so that, once the y-axis is made to increase upwards
    # (see set_ylim below), the image still appears the right way up.
    img = np.flipud(data.coins())

    n_rows, n_cols = img.shape
    frac = 0.2  # size of each marginal plot relative to the image axes

    fwidth = 8  # width (inches) budgeted for the image axes
    fheight = fwidth * (n_rows / n_cols)  # height matching the image's proportions

    # The extra column is `frac` times the image axes' *width* and the extra
    # row is `frac` times its *height*, so both figure dimensions grow by the
    # same factor (1 + frac). The image axes therefore keep the image's
    # proportions, up to the small margins added by tight_layout.
    fig, ax = plt.subplots(
        2,
        2,
        sharex="col",
        sharey="row",
        # Passed via gridspec_kw so this also works on matplotlib < 3.6, where
        # subplots() does not accept width_ratios/height_ratios directly.
        gridspec_kw={"width_ratios": [frac, 1], "height_ratios": [frac, 1]},
        figsize=[(1 + frac) * fwidth, (1 + frac) * fheight],
    )

    # aspect="auto" lets the image fill its grid cell so that its edges line up
    # with the marginal plots; the figure size above keeps it (approximately)
    # undistorted.
    ax[1, 1].imshow(img, aspect="auto")

    # Row means, negated so that the profile grows towards the left.
    ax[1, 0].plot(-img.mean(axis=1), range(n_rows))
    ax[1, 0].set_title('Left Plot')

    # Column means.
    ax[0, 1].plot(img.mean(axis=0))
    ax[0, 1].set_title('Top Plot')

    # Pixel i spans [i - 0.5, i + 0.5], so these limits show every pixel in
    # full (limits of [0, n - 1] would clip half a pixel at each edge).
    # Putting the smaller y limit first makes the y-axis increase upwards.
    ax[1, 1].set_xlim(-0.5, n_cols - 0.5)
    ax[1, 1].set_ylim(-0.5, n_rows - 0.5)

    ax[0, 0].axis("off")  # the top-left cell is unused

    fig.tight_layout()
    plt.show()
    

    This produces:

    enter image description here