Tags: python, python-3.x, numpy, matplotlib, plotly

Plot a 2D array with axes labelled with the array values of a 1D array


I have a 1D array arr1d = ['A', 'B', 'C', 'D'] and a 2D array, say

arr2d = [[1, 2, 5, 3], [2, 1, 2, 5], [5, 3, 4, 4], [5, 5, 3, 4]]

I want to plot the 2D array with matplotlib (or any other library) so that the output looks like the picture below, with the x and y axes labelled using the entries of the 1D array. How can I achieve that?

[Image: desired output]


Solution

  • You could try something like this:

    from matplotlib import pyplot as plt
    from matplotlib import colors
    import numpy as np
    
    arr1d = ['A', 'B', 'C', 'D']
    data = np.array([[1, 2, 5, 3], [2, 1, 2, 5], [5, 3, 4, 4], [5, 5, 3, 4]])
    # One color per distinct value (1 to 5 in this data)
    cmap = colors.ListedColormap(['blue', 'red', 'green', 'yellow', 'cyan'])
    fig, ax = plt.subplots()
    ax.imshow(data, cmap=cmap)
    
    # Write each value at the centre of its cell: column j is x, row i is y
    for (i, j), z in np.ndenumerate(data):
        ax.text(j, i, '{}'.format(z), ha='center', va='center', size=12)
    
    # Replace the numeric tick labels 0..3 with the entries of arr1d
    plt.xticks(np.arange(len(arr1d)), arr1d)
    plt.yticks(np.arange(len(arr1d)), arr1d)
    plt.show()
    


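    With ListedColormap you choose the colors explicitly. imshow scales the data range (1 to 5 here) linearly onto the list, so the colors are assigned in value order: 1 is blue, 2 red, 3 green, 4 yellow and 5 cyan.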
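The colors line up with the values only because this data contains exactly the five consecutive integers 1 to 5, which imshow scales linearly onto the five colors. If a value is missing from the array, or the values are not consecutive, the assignment shifts. A minimal sketch that pins each value to its color explicitly with matplotlib's BoundaryNorm, assuming integer values 1 to 5 as in this example (the bin edges halfway between the integers are that assumption made concrete):

```python
import numpy as np
from matplotlib import colors
from matplotlib import pyplot as plt

arr1d = ['A', 'B', 'C', 'D']
data = np.array([[1, 2, 5, 3], [2, 1, 2, 5], [5, 3, 4, 4], [5, 5, 3, 4]])

# Assumption: the possible values are the integers 1..5, listed in order
values = [1, 2, 3, 4, 5]
cmap = colors.ListedColormap(['blue', 'red', 'green', 'yellow', 'cyan'])

# Bin edges 0.5, 1.5, ..., 5.5: value v falls in bin v - 1 and therefore
# always gets color number v - 1, whatever the data's min and max are
bounds = np.arange(values[0] - 0.5, values[-1] + 1.0, 1.0)
norm = colors.BoundaryNorm(bounds, cmap.N)

fig, ax = plt.subplots()
ax.imshow(data, cmap=cmap, norm=norm)
for (i, j), z in np.ndenumerate(data):
    ax.text(j, i, str(z), ha='center', va='center', size=12)

# Same labelling as above, using the Axes methods
ax.set_xticks(np.arange(len(arr1d)))
ax.set_xticklabels(arr1d)
ax.set_yticks(np.arange(len(arr1d)))
ax.set_yticklabels(arr1d)
plt.show()
```

With this norm, an array that only contains 1, 2 and 5 still draws 5 as cyan instead of stretching the colors over the smaller range.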

    Or with grid lines:

    from matplotlib import pyplot as plt
    from matplotlib import colors
    import numpy as np
    
    arr1d = ['A', 'B', 'C', 'D']
    data = np.array([[1, 2, 5, 3], [2, 1, 2, 5], [5, 3, 4, 4], [5, 5, 3, 4]])
    cmap = colors.ListedColormap(['blue', 'red', 'green', 'yellow', 'cyan'])
    
    _, ax = plt.subplots()
    # Write each value near the middle of its cell (cell (i, j) spans [j, j+1] x [i, i+1])
    for (i, j), z in np.ndenumerate(data):
        ax.annotate('{}'.format(z), xy=(j + 0.4, i + 0.6), fontsize=15)
    
    # extent is (left, right, bottom, top): width is the number of columns,
    # height the number of rows; bottom > top keeps row 0 at the top
    ax.imshow(data, cmap=cmap, extent=(0, data.shape[1], data.shape[0], 0))
    # Grid lines are drawn at the tick positions, i.e. on the cell edges
    ax.grid(color='black', linewidth=3)
    
    plt.xticks(np.arange(len(arr1d)), arr1d)
    plt.yticks(np.arange(len(arr1d)), arr1d)
    plt.show()
    

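In the grid version the labels A to D sit on the grid lines rather than under the cells, because once extent is set the ticks at 0, 1, 2, 3 are the cell edges. A common way to keep the labels centred is to put the labels on major ticks at the cell centres and draw the grid on minor ticks at the cell edges. A minimal sketch of that; the helper name plot_labelled_grid and its separate row/column label arguments are hypothetical, added so the same code also handles non-square arrays:

```python
import numpy as np
from matplotlib import colors
from matplotlib import pyplot as plt


def plot_labelled_grid(data, row_labels, col_labels, cmap, ax=None):
    """Hypothetical helper: draw `data` as colored cells, write each value in
    its cell, and label rows on the y axis and columns on the x axis."""
    data = np.asarray(data)
    n_rows, n_cols = data.shape
    if ax is None:
        _, ax = plt.subplots()
    # Default extent: cell (i, j) is centred on x=j, y=i
    ax.imshow(data, cmap=cmap)
    for (i, j), z in np.ndenumerate(data):
        ax.text(j, i, str(z), ha='center', va='center', fontsize=15)
    # Major ticks at the cell centres carry the labels
    ax.set_xticks(np.arange(n_cols))
    ax.set_xticklabels(col_labels)
    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels(row_labels)
    # Minor ticks at the cell edges carry the grid lines (tick marks hidden)
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.grid(which='minor', color='black', linewidth=3)
    ax.tick_params(which='minor', length=0)
    return ax


arr1d = ['A', 'B', 'C', 'D']
data = np.array([[1, 2, 5, 3], [2, 1, 2, 5], [5, 3, 4, 4], [5, 5, 3, 4]])
cmap = colors.ListedColormap(['blue', 'red', 'green', 'yellow', 'cyan'])
ax = plot_labelled_grid(data, arr1d, arr1d, cmap)
plt.show()
```

For a non-square array, pass different label lists, e.g. plot_labelled_grid([[1, 2, 5], [2, 1, 2]], ['r1', 'r2'], ['A', 'B', 'C'], cmap).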