Tags: python, matplotlib, subplot, imshow, matplotlib-gridspec

How to plot images in subplots


Suppose I have 3 directories of .jpg files: dataset 1, dataset 2, dataset 3.

I would like to make a 5-by-3 grid of subplots using matplotlib. In each row, the three subplots should show images from dataset 1, dataset 2 and dataset 3, in that order. The expected layout is:

    plot1   plot2   plot3
    plot4   ...
    ...
    plot13  plot14  plot15

How should I do that?

Something like this:

    plt.figure(figsize=(10, 10))
    for data1, data2, data3 in zip(dataset1, dataset2, dataset3):
        ...
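The row-by-row order the question asks for is exactly what `zip` produces: it groups the i-th file of each dataset into one tuple, and flattening those tuples gives the left-to-right, top-to-bottom order that a flattened 5 × 3 axes array expects. A minimal sketch, with hypothetical file names standing in for the real (sorted) directory listings:

```python
from itertools import chain

# Hypothetical file names standing in for sorted directory listings
dataset1 = [f"d1_img{i}.jpg" for i in range(1, 6)]
dataset2 = [f"d2_img{i}.jpg" for i in range(1, 6)]
dataset3 = [f"d3_img{i}.jpg" for i in range(1, 6)]

# zip yields one (dataset1, dataset2, dataset3) tuple per row
rows = list(zip(dataset1, dataset2, dataset3))

# chain the tuples into one flat list in row-major (plot1 ... plot15) order
ordered = list(chain.from_iterable(rows))

for row_number, row in enumerate(rows, start=1):
    print(row_number, row)
```

Note that `zip` stops at the shortest input, so if the directories hold different numbers of files, the extra images are silently dropped.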

Solution

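    import matplotlib.pyplot as plt
    from pathlib import Path
    
    # create a list of directories
    dirs = ['../Pictures/dataset1', '../Pictures/dataset2', '../Pictures/dataset3']
    
    # sort each directory's .jpg files, then interleave them with zip so that
    # each row holds one image from dataset1, dataset2 and dataset3, in that order
    per_dir = [sorted(Path(dir_).glob('*.jpg')) for dir_ in dirs]
    files = [f for row in zip(*per_dir) for f in row]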
    
    # create the figure
    fig, axes = plt.subplots(nrows=5, ncols=3, figsize=(10, 10), tight_layout=True)
    
    # flatten the 2-D array of axes into 1-D so each Axes can be reached in one loop
    axes = axes.flatten()
    
    # iterate through axes and associated file
    for ax, file in zip(axes, files):
        
        # read the image in
        pic = plt.imread(file)
    
        # add the image to the axes
        ax.imshow(pic)
    
        # add an axes title; Path.stem is the filename without its extension
        ax.set(title=file.stem)
    
        # remove ticks / labels
        ax.axis('off')
    
    # add a figure title
    _ = fig.suptitle('Images from https://www.heroforge.com/', fontsize=18)
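Because `axes.flatten()` returns the axes in row-major order, the file at flat index i (counting from 0) lands in row `i // ncols` and column `i % ncols`. The sketch below uses `divmod` to print that mapping for the 5 × 3 grid, numbering plots from 1 as in the question; `grid_position` is a hypothetical helper for illustration, not part of matplotlib.

```python
def grid_position(index, ncols=3):
    """Return the (row, col) cell that a flat, row-major index maps to."""
    return divmod(index, ncols)

# Show where plot1 ... plot15 end up in a 5 x 3 grid
for index in range(5 * 3):
    row, col = grid_position(index)
    print(f"plot{index + 1}: row {row}, column {col} (dataset {col + 1})")
```

Since the column index equals the dataset index, the file list has to be interleaved by dataset rather than concatenated directory after directory; otherwise the first row would show three images from dataset 1.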
    



No Whitespace Between Images

    # read in all the images, which are all the same size
    images = [plt.imread(file) for file in files]
    
    # one height per row (the first image of each row); the count must match nrows
    heights = [im.shape[0] for im in images[::3]]
    
    # one width per column (the images of the first row); the count must match ncols
    widths = [im.shape[1] for im in images[:3]]
    
    # set the figure width in inches
    fig_width = 9
    
    # calculate the figure height so it matches the aspect ratio of the image grid
    fig_height = fig_width * sum(heights) / sum(widths)
    
    # create the figure
    fig, axes = plt.subplots(nrows=5, ncols=3, figsize=(fig_width, fig_height), 
                             gridspec_kw={'wspace': 0, 'hspace': 0, 'left': 0, 'right': 1,
                                          'bottom': 0, 'top': 1, 'height_ratios': heights,
                                          'width_ratios': widths})
    
    # flatten the 2-D array of axes into 1-D so each Axes can be reached in one loop
    axes = axes.flatten()
    
    # iterate through the axes and associated images
    for ax, image in zip(axes, images):
    
        # add the image to the axes
        ax.imshow(image)
    
        # remove ticks / labels
        ax.axis('off')
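The figure-size arithmetic can be checked without loading any images. A minimal sketch, assuming each image is described by a NumPy-style `(height, width, channels)` shape tuple and the shapes are listed in row-major order; `grid_figsize` is a hypothetical helper and the 200 × 300 pixel size is made up for illustration:

```python
def grid_figsize(shapes, nrows, ncols, fig_width=9):
    """Return (heights, widths, (fig_width, fig_height)) for a gapless image grid.

    shapes: image shape tuples (height, width, ...) listed row by row.
    """
    # one height per row: the first image in each row
    heights = [shapes[r * ncols][0] for r in range(nrows)]
    # one width per column: the images in the first row
    widths = [shapes[c][1] for c in range(ncols)]
    # make the figure's aspect ratio equal to the mosaic's pixel aspect ratio
    fig_height = fig_width * sum(heights) / sum(widths)
    return heights, widths, (fig_width, fig_height)

# 15 hypothetical RGB images, each 200 px high and 300 px wide
shapes = [(200, 300, 3)] * 15
heights, widths, figsize = grid_figsize(shapes, nrows=5, ncols=3)
print(heights, widths, figsize)  # [200, 200, 200, 200, 200] [300, 300, 300] (9, 10.0)
```

With `left=0, right=1, bottom=0, top=1` and zero spacing, the axes fill the whole figure, so matching the figure's aspect ratio to the mosaic's is what keeps `imshow` (which uses an equal aspect ratio by default) from leaving gaps.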
    
