Tags: python, image, matplotlib, subplot

How to add an image as the background of a matplotlib figure (not to the plots, but to the "whitespace", à la set_facecolor())


I am trying to add an image to the "whitespace" behind the various subplots of a matplotlib figure.

Most discussions of this topic are about adding images to the plots themselves; I have not yet come across a way to change the background of the overall "canvas".

The closest function I have found is set_facecolor(); however, it only allows a single color to be set as the background.

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(2, 2)
    fig.patch.set_facecolor('xkcd:mint green')  # sets a single color only
    plt.show()

[Image: current output]

However, I am looking for a way to place an image behind the plots, similar to this (made manually):

[Image: desired output]

I have searched Google, Stack Overflow, and the matplotlib docs, but I only find results for plt.imshow(image), set_facecolor(), or similar.


Solution

  • You can add a dummy Axes that covers the whole figure and draw the background image onto it with imshow().

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np
    
    image = plt.imread('test.jpg')  # path to your background image
    
    # make ticks and labels white, for readability on a colored background
    mpl.rcParams.update({'xtick.color': "white",
                         'ytick.color': "white",
                         'axes.labelcolor': "white"})
    # create a figure with 4 subplots, with the same aspect ratio as the image
    width = 8
    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(width, width * image.shape[0] / image.shape[1]))
    for ax in np.ravel(axs):
        ax.patch.set_alpha(0.7)  # make subplots semi-transparent
    
    # plot something onto the subplots
    t = np.linspace(0, 8 * np.pi, 2000)
    for i in range(2):
        for j in range(2):
            axs[i, j].plot(np.sin(t * (i + 2)), np.sin(t * (j + 4)))
    
    # arrange the subplots BEFORE adding the background axes: tight_layout()
    # only handles subplots and warns about an Axes created with add_axes()
    fig.tight_layout()
    
    background_ax = fig.add_axes([0, 0, 1, 1])  # dummy axes covering the whole figure
    background_ax.set_zorder(-1)  # draw the background axes behind the others
    background_ax.imshow(image, aspect='auto')  # show the background image, stretched to fill
    background_ax.axis('off')  # hide the dummy axes' ticks and frame
    
    plt.show()
    

[Image: using an image as figure background]
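An alternative worth trying (not part of the original answer) is Figure.figimage(), which draws an image directly onto the figure canvas instead of into an Axes. Unlike imshow(), it does not resample the image, so this sketch passes resize=True to make the figure exactly as large as the image in pixels, and zorder=-1 so it is drawn behind the subplots. As assumptions for a self-contained example, a synthetic gradient array stands in for plt.imread('test.jpg') and the Agg backend is used so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Assumption: a synthetic 400x600 RGB gradient replaces plt.imread('test.jpg')
h, w = 400, 600
image = np.zeros((h, w, 3))
image[..., 0] = np.linspace(0, 1, w)[np.newaxis, :]  # red increases left to right
image[..., 2] = np.linspace(1, 0, h)[:, np.newaxis]  # blue decreases top to bottom

fig, axs = plt.subplots(nrows=2, ncols=2)
for ax in axs.flat:
    ax.patch.set_alpha(0.7)  # semi-transparent subplots so the background shows through

# figimage places raw pixels in figure coordinates (no resampling).
# resize=True sets the figure size to (image width, height) / dpi so the image fills it;
# zorder=-1 draws it before the Axes, which have the default zorder 0.
fig.figimage(image, resize=True, zorder=-1)

t = np.linspace(0, 8 * np.pi, 2000)
for i in range(2):
    for j in range(2):
        axs[i, j].plot(np.sin(t * (i + 2)), np.sin(t * (j + 4)))

fig.tight_layout()  # figimage is not an Axes, so tight_layout has nothing to warn about
fig.canvas.draw()   # render once; use plt.show() or fig.savefig(...) in practice
```

Because the pixels are not resampled, this suits the case where the figure should be sized to the image. When the figure size is chosen independently, the dummy-axes approach above, which stretches the image with aspect='auto', is probably the simpler choice.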