python numpy vectorization one-hot-encoding

How do I decode a one-hot encoded NumPy matrix in a fast manner using vectorization?


I have an image matrix of shape (height, width) with uint8 values. It was one-hot encoded (converted to categorical) to shape (height, width, n), where n is the number of possible categories. Here n = 3, so the encoded shape is (height, width, 3). I would like to undo the categorical conversion and recover the original (height, width) matrix. The following solution works, but could be made much faster:

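import numpy

def decode(image):
    # image has shape (height, width, n), one-hot along the last axis
    height, width = image.shape[:2]

    # Allocate the output; every element is overwritten below
    decoded_image = numpy.empty((height, width), dtype=numpy.uint8)

    for i in range(height):
        for j in range(width):
            # The position of the 1 in the one-hot vector is the category
            decoded_image[i, j] = numpy.argmax(image[i, j])

    return decoded_image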

I would like a solution that uses NumPy vectorization and avoids the slow nested Python for loops.

Thank you for any suggestions.


Solution

  • It looks like you want a reduction over the last dimension of your array, specifically numpy.argmax. Fortunately, numpy.argmax accepts an axis keyword, so the whole double loop can be replaced with a single call:

    decoded_image = numpy.argmax(image, axis=2)
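
As a quick sanity check, here is a minimal sketch comparing the single-call version against the original double loop on a small synthetic image. The names decode_loop, decode_vectorized, labels and one_hot are illustrative and not part of the question, and the test data is randomly generated. One detail worth noting: numpy.argmax returns a platform integer dtype rather than uint8, so the sketch casts the result back with astype to match the original function's output.

```python
import numpy

def decode_loop(image):
    # Reference implementation: the nested-loop approach from the question
    height, width = image.shape[:2]
    decoded = numpy.empty((height, width), dtype=numpy.uint8)
    for i in range(height):
        for j in range(width):
            decoded[i, j] = numpy.argmax(image[i, j])
    return decoded

def decode_vectorized(image):
    # Reduce over the last (category) axis in one call, then cast
    # back to uint8 because argmax returns a wider integer dtype
    return numpy.argmax(image, axis=-1).astype(numpy.uint8)

# Build a small synthetic label image and one-hot encode it (assumed test data)
rng = numpy.random.default_rng(0)
labels = rng.integers(0, 3, size=(4, 5), dtype=numpy.uint8)
one_hot = numpy.eye(3, dtype=numpy.uint8)[labels]  # shape (4, 5, 3)

decoded = decode_vectorized(one_hot)
print(numpy.array_equal(decoded, labels))
print(numpy.array_equal(decoded, decode_loop(one_hot)))
```

Using axis=-1 instead of axis=2 is equivalent for a (height, width, n) array and also works if a leading batch dimension is added later.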