Tags: python, plot, visualization, similarity, axis-labels

How to plot an image-image similarity matrix map?


I am trying to plot an image like the following: [image]

The x-axis and y-axis show a set of images, and the map shows their pairwise similarity.

How can I plot this figure in Python 3? I have 50 images, so it will be a 50x50 matrix map with the 50 images along both the x-axis and the y-axis.

My goal is to plot a figure like the one below (that one is asymmetric and its x and y labels are text; I want mine to be images): [image]

Let's assume the similarity is random:

import numpy as np

S = np.random.rand(3,3)

and that there are 3 images in some path.
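For exactly this setup (a random 3x3 S and 3 images), one possible matplotlib-only sketch is shown here. It is an illustration under assumptions, not the method from the answer below: it uses imshow instead of seaborn, assumes similarities lie in [0, 1], uses random 28x28 arrays as stand-ins for the three images, and plot_similarity_with_images is a hypothetical helper name. Each label image is placed with matplotlib's AnnotationBbox using blended transforms (ax.get_xaxis_transform() and ax.get_yaxis_transform()), so it sits just outside the axes next to its row or column.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnnotationBbox, OffsetImage


def plot_similarity_with_images(S, images, zoom=1.0, pad=3, figsize=(5, 5)):
    """Hypothetical helper: show the n x n matrix S with images[k] as the
    label of row k and of column k. Assumes similarities lie in [0, 1]."""
    n = len(images)
    if S.shape != (n, n):
        raise ValueError("S must be n x n for n images")
    fig, ax = plt.subplots(figsize=figsize)
    mesh = ax.imshow(S, cmap="viridis", vmin=0, vmax=1)
    fig.colorbar(mesh, ax=ax)
    # Print each similarity value at the centre of its cell.
    for i in range(n):
        for j in range(n):
            ax.text(j, i, f"{S[i, j]:.2f}", ha="center", va="center", color="white")
    # Remove the numeric ticks; the images take their place.
    ax.set_xticks([])
    ax.set_yticks([])
    for k, img in enumerate(images):
        # Column label: x in data units (column centre), y in axes units (0 = bottom edge).
        # The box's top-centre is anchored `pad` points below that point.
        ax.add_artist(AnnotationBbox(
            OffsetImage(img, zoom=zoom, cmap="gray"), (k, 0),
            xycoords=ax.get_xaxis_transform(), xybox=(0, -pad),
            boxcoords="offset points", box_alignment=(0.5, 1),
            frameon=False, pad=0, annotation_clip=False, clip_on=False))
        # Row label: x in axes units (0 = left edge), y in data units (row centre).
        # The box's right-centre is anchored `pad` points left of that point.
        ax.add_artist(AnnotationBbox(
            OffsetImage(img, zoom=zoom, cmap="gray"), (0, k),
            xycoords=ax.get_yaxis_transform(), xybox=(-pad, 0),
            boxcoords="offset points", box_alignment=(1, 0.5),
            frameon=False, pad=0, annotation_clip=False, clip_on=False))
    return fig, ax


# Random similarity, as in the question.
S = np.random.rand(3, 3)
# Placeholder 28x28 images; real ones could be loaded with plt.imread("<your path>").
rng = np.random.default_rng(0)
images = [rng.random((28, 28)) for _ in range(3)]

fig, ax = plot_similarity_with_images(S, images)
fig.savefig("similarity_matrix.png", bbox_inches="tight")  # or plt.show()
```

Because each label's position is given in data units along one axis and axes units along the other, this sketch needs no hand-tuned point offsets; zoom (image scale) and pad (gap in points) are the knobs to adjust. For 50 images you would likely want a larger figsize and a smaller zoom.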


Solution

  • I implemented the following code snippet using libraries that can be installed with pip install numpy matplotlib seaborn tensorflow.

    TensorFlow is only used to load some example images (MNIST digits in my case). In your case you don't need TensorFlow, since you have your own images.

    Seaborn is used to draw the heatmap, a correlation-like matrix plot.

    In the code, the HeatMap() function draws the heatmap itself. The ImgLabels() function draws images as X/Y tick labels; it uses hand-drawn MNIST digits from TensorFlow's MNIST dataset, but you can use any images instead. ImgLabels() has one set of tweakable parameters, the offsets -18/-14/+14 (in points). These numbers depend on the size of the squares that hold the image labels, so you may need to change them in your case.

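    The similarity matrix is random in my case: I generate numbers uniformly in the range [-1, +1]. The upper-right triangle of the matrix (including the diagonal) is masked out and left white, because for a symmetric similarity matrix it would only repeat the lower triangle.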

    If you don't want numbers drawn inside the similarity cells, set annot = False instead of the current annot = True.

    Running the code produces an image like this: [image]

    Code:

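    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib.offsetbox import AnnotationBbox, OffsetImage
    
    def ImgLabels(N, ax):
        imgs = None  # loaded lazily on first use
        def offset_image(coord, name, ax):
            nonlocal imgs
            if imgs is None:
                # MNIST digits are only example images; use your own instead.
                import tensorflow as tf
                (imgs, _), (_, _) = tf.keras.datasets.mnist.load_data()
    
            img = imgs[name]
            im = OffsetImage(img, zoom = 0.9)
            im.image.axes = ax
    
            # Draw the image twice: as a row label at the left edge (x = 0) and as a
            # column label at the bottom edge (y = N). The xybox offsets (in points)
            # move each image out of the axes and from the cell edge toward the middle
            # of its row/column; they are tuned for this figure and image size.
            for co, xyb in [((0, coord), (-18, -14)), ((coord, N), (+14, -18))]:
                ab = AnnotationBbox(im, co, xybox = xyb, frameon = False,
                    xycoords = 'data', boxcoords = "offset points", pad = 0)
                ax.add_artist(ab)
    
        for i in range(N):
            offset_image(i, i, ax)  # image i labels row i and column i
    
    def HeatMap(N, ax):
        sns.set_theme(style = "white")
        # Random "similarities" in [-1, 1); replace with your own N x N matrix.
        corr = np.random.uniform(-1, 1, size = (N, N))
        # Hide the upper triangle, including the diagonal.
        mask = np.triu(np.ones_like(corr, dtype = bool))
        cmap = sns.diverging_palette(230, 20, as_cmap = True)
        sns.heatmap(corr, mask = mask, cmap = cmap, vmin = -1, vmax = 1, center = 0,
                    square = True, annot = True, xticklabels = False, yticklabels = False,
                    ax = ax)
    
    N = 10
    f, ax = plt.subplots(figsize = (7, 5))
    HeatMap(N, ax)
    ImgLabels(N, ax)
    
    plt.show()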
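Since TensorFlow is only used for example images, here is a hedged sketch of loading your own images from a folder instead. load_images and zoom_for are hypothetical helpers (not part of the answer or of any library); the sketch assumes PNG files readable by matplotlib's plt.imread, and it converts color images to grayscale by averaging the RGB channels. It also assumes the answer's -18/-14/+14 offsets were tuned for 28x28 MNIST digits drawn at zoom 0.9, so zoom_for picks a zoom that draws other image sizes at roughly that height.

```python
import tempfile
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt


def load_images(folder, pattern="*.png"):
    """Hypothetical helper: read the images in `folder` matching `pattern`,
    sorted by file name, as 2-D grayscale float arrays."""
    images = []
    for path in sorted(Path(folder).glob(pattern)):
        img = plt.imread(str(path))  # 8-bit PNGs are returned as floats in [0, 1]
        if img.ndim == 3:
            # Ignore any alpha channel and average R, G, B into one gray channel.
            img = img[..., :3].mean(axis=2)
        images.append(img)
    return images


def zoom_for(img, target_px=28 * 0.9):
    """Hypothetical helper: OffsetImage zoom that draws `img` about `target_px`
    pixels tall (default: a 28x28 MNIST digit at zoom 0.9)."""
    return target_px / img.shape[0]


# Demo only: write three random 28x28 grayscale PNGs to a temporary folder,
# then read them back with load_images.
rng = np.random.default_rng(0)
with tempfile.TemporaryDirectory() as folder:
    for k in range(3):
        plt.imsave(str(Path(folder) / f"img_{k}.png"), rng.random((28, 28)), cmap="gray")
    imgs = load_images(folder)

print(len(imgs), imgs[0].shape, round(zoom_for(imgs[0]), 3))
```

Inside ImgLabels() you could then assign imgs = load_images(...) with your own folder instead of loading MNIST (a list supports the same imgs[name] indexing) and pass zoom = zoom_for(img) to OffsetImage; the offsets may still need some adjustment. To keep color, skip the grayscale conversion, since OffsetImage also accepts RGB(A) arrays.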