
Displaying different grayscale images at their actual size in a matplotlib subplot


I use the following function to display different images at their actual size in a matplotlib subplot. I got this function from this reference: https://stackoverflow.com/a/53816322/22241489

def display_image(im_path):

    dpi = 80
    im_data = plt.imread(im_path)
    height, width, depth = im_data.shape

    # What size does the figure need to be in inches to fit the image?
    figsize = width / float(dpi), height / float(dpi)

    # Create a figure of the right size with one axes that takes up the full figure
    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0, 0, 1, 1])

    # Hide spines, ticks, etc.
    ax.axis('off')

    # Display the image.
    ax.imshow(im_data, cmap='gray')

    plt.show()
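A quick check of the sizing arithmetic in the function above: figsize is given in inches, so dividing the pixel dimensions by dpi gives a figure that holds the image pixel for pixel when the figure is rendered at that same dpi. The helper name figure_size_inches is hypothetical; it is only a sketch of that one line, not part of matplotlib.

```python
def figure_size_inches(width_px, height_px, dpi=80):
    """Hypothetical helper mirroring the figsize line in display_image.

    Returns (width, height) in inches, the order matplotlib's figsize expects.
    Rendered at `dpi` dots per inch, this is exactly width_px x height_px pixels.
    """
    return width_px / float(dpi), height_px / float(dpi)


# A 640 x 480 image at 80 dpi needs an 8 x 6 inch figure.
print(figure_size_inches(640, 480))  # (8.0, 6.0)
```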

The problem is that when I convert an image to grayscale with OpenCV as follows,

import cv2 as cv

base_image = cv.imread(image)  # image is the path to the source file

gray = cv.cvtColor(base_image, cv.COLOR_BGR2GRAY)

cv.imwrite('temp/gray.jpg', gray)

I can no longer use the function above, because it raises this error:

ValueError                                Traceback (most recent call last)
Cell In[17], line 1
----> 1 display_image('temp/gray.jpg')

Cell In[10], line 5, in display_image(im_path)
      3 dpi = 80
      4 im_data = plt.imread(im_path)
----> 5 height, width, depth = im_data.shape
      7 # What size does the figure need to be in inches to fit the image?
      8 figsize = width / float(dpi), height / float(dpi)

ValueError: not enough values to unpack (expected 3, got 2)

How do I fix this issue?


Solution

  • When you convert to grayscale, the color dimension of the array no longer exists. A color image is an array of shape height x width x 3 (for a standard 3-channel image). A grayscale image has only one channel, so that last axis is dropped and the array has shape height x width.

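    Since you aren't using depth, you can capture it with a starred target (extended iterable unpacking). Any remaining dimensions are collected into a list, so depth is [3] for a color image and an empty list [] for a grayscale image, and the assignment no longer raises an error.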

    import matplotlib.pyplot as plt

    def display_image(im_path):
    
        dpi = 80
        im_data = plt.imread(im_path)
        height, width, *depth = im_data.shape
    
        # What size does the figure need to be in inches to fit the image?
        figsize = width / float(dpi), height / float(dpi)
    
        # Create a figure of the right size with one axes that takes up the full figure.
        # Pass the same dpi so the figure's resolution matches the figsize calculation.
        fig = plt.figure(figsize=figsize, dpi=dpi)
        ax = fig.add_axes([0, 0, 1, 1])
    
        # Hide spines, ticks, etc.
        ax.axis('off')
    
        # Display the image.
        ax.imshow(im_data, cmap='gray')
    
        plt.show()
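To see why the starred target handles both cases, here is a small sketch that uses synthetic NumPy arrays in place of real images (the 480 x 640 size and the helper image_size are hypothetical, chosen only for illustration). A color array has shape height x width x 3 and a grayscale one height x width, and the same unpacking line works for both. Slicing shape[:2] is an equivalent alternative if you prefer not to unpack the channel axis at all.

```python
import numpy as np


def image_size(im_data):
    """Hypothetical helper: return (height, width) of a 2-D or 3-D image array."""
    # The starred target collects any trailing axes (the color channels)
    # into a list; the list is simply empty for a grayscale image.
    height, width, *depth = im_data.shape
    return height, width


# Synthetic stand-ins for arrays returned by plt.imread (no files are read).
color = np.zeros((480, 640, 3), dtype=np.uint8)  # height x width x 3
gray = np.zeros((480, 640), dtype=np.uint8)      # height x width

_, _, *depth_color = color.shape
_, _, *depth_gray = gray.shape
print(depth_color)  # [3]
print(depth_gray)   # []

print(image_size(color))  # (480, 640)
print(image_size(gray))   # (480, 640)

# Equivalent alternative: take only the first two dimensions.
height, width = gray.shape[:2]
print(height, width)  # 480 640
```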