python numpy visualization image-segmentation semantic-segmentation

How to visualize a matrix of categories as an RGB image?


I am using a neural network for semantic segmentation (human parsing): it takes a photo of people as input and predicts, for every pixel, whether it is most likely head, leg, background, or some other body part. The algorithm runs smoothly and returns a numpy.ndarray of shape (1, 23, 600, 400), where 600*400 is the resolution of the input image and 23 is the number of categories. Ignoring the leading dimension of size 1, the 3D matrix is a stack of 23 2D matrices, where each layer is a matrix of floats giving the probability that each pixel belongs to that category.

To get a visualization like the following figure, I used numpy.argmax to collapse the 3D matrix into a 2D matrix that holds, for each pixel, the index of the most probable category. But I have no idea how to proceed from there to get the visualization I want.
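The argmax step described above can be sketched as follows. The random `prediction` array is only a stand-in for the network output (an assumption for illustration); what matters is how the shapes change.

```python
import numpy as np

# Stand-in for the network output: (batch, categories, height, width).
rng = np.random.default_rng(0)
prediction = rng.random((1, 23, 600, 400))

# Collapse the category axis: each pixel keeps the index of its highest score.
labels = np.argmax(prediction, axis=1)

print(labels.shape)  # (1, 600, 400)
```

Each entry of `labels` is an integer between 0 and 22, so the remaining problem is mapping those integers to colors.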

The desired visualization effect

EDIT

Actually, I can do it in a trivial way: use a for loop to traverse every pixel and assign it a color to build an image. However, this is not vectorized code, and numpy has built-in ways to speed up matrix manipulation. I need to save CPU cycles for real-time segmentation.


Solution

  • It's fairly easy. All you need to have is a lookup table mapping the 23 labels into unique colors. The easiest way is to have a 23-by-3 numpy array with each row storing the RGB values for the corresponding label:

    import numpy as np
    import matplotlib.pyplot as plt
    lut = np.random.rand(23, 3)   # using random mapping - but you can do better
    lb = np.argmax(prediction, axis=1)  # converting probabilities to discrete labels
    rgb = lut[lb[0, ...], :]  # this is all it takes to do the mapping.
    plt.imshow(rgb)
    plt.show()
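As a rough sketch of "doing better" than a fresh random mapping on every run, the lookup table can be made reproducible and stored as uint8, which is the dtype most image writers expect. The names `colorize` and `colorize_loop`, the seed, and the choice of black for label 0 are assumptions made for illustration, not part of the original answer. The loop version is included only to check that the vectorized lookup produces the same image.

```python
import numpy as np

NUM_CLASSES = 23

# Hypothetical fixed palette: one uint8 RGB row per label, drawn from a
# seeded generator so the colors are identical on every run.
rng = np.random.default_rng(42)
lut = rng.integers(0, 256, size=(NUM_CLASSES, 3), dtype=np.uint8)
lut[0] = (0, 0, 0)  # assumption: label 0 is background, drawn in black


def colorize(labels, lut):
    """Map an (H, W) integer label map to an (H, W, 3) RGB image."""
    return lut[labels]  # fancy indexing: one table lookup per pixel, no Python loop


def colorize_loop(labels, lut):
    """Reference implementation with explicit loops, for comparison only."""
    h, w = labels.shape
    out = np.empty((h, w, 3), dtype=lut.dtype)
    for i in range(h):
        for j in range(w):
            out[i, j] = lut[labels[i, j]]
    return out


labels = rng.integers(0, NUM_CLASSES, size=(60, 40))  # small stand-in label map
rgb = colorize(labels, lut)
same = np.array_equal(rgb, colorize_loop(labels, lut))
print(rgb.shape, rgb.dtype, same)
```

Because the lookup is a single indexing operation, its cost is dominated by copying the output array, which suits the real-time constraint mentioned in the question.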
    

    Alternatively, if you are only interested in the colormap for display purposes, you can use the cmap argument of plt.imshow, but this requires you to convert lut into a "colormap":

    from matplotlib.colors import LinearSegmentedColormap
    cmap = LinearSegmentedColormap.from_list('new_map', lut, N=23)
    plt.imshow(lb[0, ...], cmap=cmap, vmin=0, vmax=22)  # pin the range so label k always maps to row k of lut
    plt.show()
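For discrete labels, matplotlib's `ListedColormap` may be a more direct fit than an interpolated colormap, since it stores exactly one color per entry. The sketch below is a swapped-in variant, not the method from the answer above. It assumes a stand-in palette and label map, uses the off-screen "Agg" backend, and writes to an in-memory buffer so that it runs without a display. Passing `vmin` and `vmax` keeps label k tied to color k even when some labels are absent from a given image.

```python
import io

import numpy as np
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

NUM_CLASSES = 23
rng = np.random.default_rng(0)
lut = rng.random((NUM_CLASSES, 3))                    # stand-in palette, floats in [0, 1)
labels = rng.integers(0, NUM_CLASSES, size=(60, 40))  # stand-in label map

cmap = ListedColormap(lut)  # one colormap entry per label, no interpolation

fig, ax = plt.subplots()
im = ax.imshow(labels, cmap=cmap, vmin=0, vmax=NUM_CLASSES - 1,
               interpolation="nearest")
fig.colorbar(im, ax=ax)

buf = io.BytesIO()  # swap for a filename to write the figure to disk
fig.savefig(buf, format="png")
plt.close(fig)

# Calling the colormap with integer indices returns the stored RGBA rows.
colors = cmap(np.arange(NUM_CLASSES))[:, :3]
```

This route only changes how the label map is displayed. If the RGB array itself is needed, for example to save or blend with the input photo, the direct `lut[lb[0, ...], :]` lookup remains the simpler option.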