Search code examples
Tags: python, matplotlib, subplot

Set a title for multiple subplots in Matplotlib


I have a dataset of image pairs built from the Fashion-MNIST dataset: each record contains two images and a label saying whether or not they belong to the same class.

I want to display a label ("match" or "mismatch") for each pair. My output so far is as follows:

(Screenshot of the current output — image not included.)

My code:

# Original attempt from the question. The label ends up in the wrong place;
# see the notes on ax.text below.
# `%matplotlib inline` is an IPython/Jupyter magic (not plain Python): it makes
# figures render inline in the notebook.
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# NOTE(review): `np`, `train_pairs` and `train_distance` are not defined in this
# snippet. Presumably numpy is imported as np and the data is built in an
# earlier cell -- confirm.
# pick 4 random pairs from the training set
# (randint samples with replacement, so the same pair can be picked twice)
random_indices = np.random.randint(0, len(train_pairs), size=4)
random_pairs = train_pairs[random_indices]
random_distance = train_distance[random_indices]

fig = plt.figure(figsize=(20, 10))
# Outer 2 x 2 grid: one cell per pair.
outer = gridspec.GridSpec(2, 2, wspace=0.2, hspace=0.2)

for i in range(4):
  # Inner 1 x 2 grid inside outer cell i: one slot per image of pair i.
  inner = gridspec.GridSpecFromSubplotSpec(1, 2,
                  subplot_spec=outer[i], wspace=0.1, hspace=0.1)

  for j in range(2):
    ax = plt.Subplot(fig, inner[j])
    
    # show the image
    ax.imshow(random_pairs[i][j])

    # show the label
    # (0, 0) is in data coordinates. After imshow (default origin='upper')
    # that is the center of the top-left pixel, so the label lands in the
    # image's top-left corner. Because this call sits inside the j loop, the
    # label is drawn on both images of the pair instead of once below the pair.
    # It also prints the raw value of random_distance[i] rather than a
    # "match"/"missmatch" word.
    ax.text(0, 0, '{}'.format(random_distance[i]),
            size=24, ha='center', va='center', color='w')

    ax.set_xticks([])
    ax.set_yticks([])
    # Attach the manually constructed Subplot to the figure.
    fig.add_subplot(ax)

# NOTE(review): with the inline (non-GUI) backend, fig.show() may only emit a
# warning. plt.show(), or nothing at all, is the usual notebook idiom -- confirm.
fig.show()

What I want is to display the label "match" or "mismatch" once per pair, centered below the two images.


Solution

  • Using subfigures gave the desired result: each pair gets its own subfigure, and calling `supxlabel` on each subfigure places its label centered below that pair.

    %matplotlib inline
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    
    # pick 4 random pairs from the training set
    random_indices = np.random.randint(0, len(train_pairs), size=4)
    random_pairs = train_pairs[random_indices]
    random_distance = train_distance[random_indices]
    
    fig = plt.figure(figsize=(20, 10))
    
    subFigs = fig.subfigures(2, 2).flatten()
    print(subFigs)
    
    for i in range(4):
      subFig = subFigs[i]
      label = "Match" if random_distance[i] else "miss-Match"
      subFig.supxlabel(label, fontsize=16, color='red')
    
      axs = subFig.subplots(1, 2)
    
      for j in range(2):
        ax = axs[j]
        
        # show the image
        ax.imshow(random_pairs[i][j])
    
        ax.set_xticks([])
        ax.set_yticks([])
        subFig.add_subplot(ax)
    
    fig.show()
    

    The result obtained: (screenshot of the labeled pairs — image not included).