
Flatten a 3D array to a 2D array using a second matrix to choose elements in third dimension


I have two input arrays: data_arr of dimensions (i,j,k) and index_arr of dimensions (i,j). The entries in index_arr are integers in the range [0, k-1]. I would like to create an output array (output_arr) of dimensions (i,j) where, for each element of output_arr, index_arr tells me which of the k elements along the third dimension of data_arr to choose.

In other words output_arr[i,j] = data_arr[i,j, index_arr[i, j]]

Clearly I could do this at glacial pace with a double for loop. I would prefer something snappier using smart indexing. Currently the best I could devise involves creating two extra 2D matrices of size (i,j).
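For concreteness, the double loop mentioned above could be sketched as follows on a small toy array. The function name select_with_loops and the toy shapes are illustrative only and do not appear in my real code; the loop is useful mainly as a ground truth to check faster versions against.

```python
import numpy as np


def select_with_loops(data_arr, index_arr):
    """Reference version: output[i, j] = data_arr[i, j, index_arr[i, j]]."""
    rows, cols, _ = data_arr.shape
    output_arr = np.empty((rows, cols), dtype=data_arr.dtype)
    for i in range(rows):
        for j in range(cols):
            output_arr[i, j] = data_arr[i, j, index_arr[i, j]]
    return output_arr


# Toy input (hypothetical): 2 x 3 positions with 4 candidates each.
data_arr = np.arange(24).reshape(2, 3, 4)
index_arr = np.array([[0, 1, 2],
                      [3, 0, 1]])

loop_result = select_with_loops(data_arr, index_arr)
print(loop_result)
```

Any vectorised version should reproduce loop_result exactly for the same inputs.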

Below is a simple MWE framed in terms of creating a mosaiced image from an RGB image using a standard Bayer pattern. I would like to be able to get rid of X_ind and Y_ind.

import numpy as np
import time


if __name__ == '__main__':
    img_width = 1920
    img_height = 1080
    img_num_colours = 3

    red_arr = np.ones([img_height, img_width], dtype=np.uint16) * 10
    green_arr = np.ones([img_height, img_width], dtype=np.uint16) * 20
    blue_arr = np.ones([img_height, img_width], dtype=np.uint16) * 30

    img_arr = np.dstack((red_arr, green_arr, blue_arr))

    bayer_arr = np.ones([img_height, img_width], dtype=np.uint16)
    bayer_arr[0::2,0::2] = 0 # Red entries in Bayer pattern
                             # Green entries are already set by np.ones initialisation
    bayer_arr[1::2,1::2] = 2 # Blue entries in Bayer pattern
    print("bayer\n",bayer_arr[:8,:12], "\n")

    # Full-size index matrices: X_ind holds each pixel's row index, Y_ind its column index
    Y_ind = np.repeat(np.arange(0, img_width).reshape(1, img_width), img_height, 0)
    X_ind = np.repeat(np.arange(0, img_height).reshape(img_height, 1), img_width, 1)

    start_time = time.time()
    demos_arr = img_arr[X_ind, Y_ind, bayer_arr]
    end_time = time.time()

    print(demos_arr.shape)
    print("demos\n",demos_arr[:8,:12], "\n")
    print("Mosaic took {:.3f}s".format(end_time - start_time)) 

Edit: As pointed out by @Georgy, this question is similar to this one, which I didn't find with my search terms, so maybe this post will act as a signpost for it. The answers in the other post are applicable, although the flattened index arithmetic is different because my dimensions are ordered differently. The approach above is equivalent to the ogrid version in the other question. In fact, ogrid can be used by making the following change to the code:

# Y_ind = np.repeat(np.arange(0, img_width).reshape(1, img_width), img_height, 0)
# X_ind = np.repeat(np.arange(0, img_height).reshape(img_height, 1), img_width, 1)
X_ind, Y_ind = np.ogrid[0:img_height, 0:img_width]
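To illustrate why this works without full-size index matrices: np.ogrid returns a column of shape (rows, 1) and a row of shape (1, cols), and NumPy broadcasts both against the (rows, cols) index array during indexing. The sketch below uses toy shapes and illustrative variable names (row_ind, col_ind), and also shows the flattened index arithmetic as I understand it for a C-ordered (i,j,k) layout.

```python
import numpy as np

rows, cols, depth = 2, 3, 4
data_arr = np.arange(rows * cols * depth).reshape(rows, cols, depth)
index_arr = np.array([[0, 1, 2],
                      [3, 0, 1]])

# Open grid: a (rows, 1) column and a (1, cols) row instead of two full matrices.
row_ind, col_ind = np.ogrid[0:rows, 0:cols]
print(row_ind.shape, col_ind.shape)  # (2, 1) (1, 3)

# The three index arrays broadcast to (rows, cols).
ogrid_result = data_arr[row_ind, col_ind, index_arr]

# Same selection via flat indices, assuming a C-ordered (rows, cols, depth) array:
# flat index of [i, j, k] is (i * cols + j) * depth + k.
flat_ind = (row_ind * cols + col_ind) * depth + index_arr
flat_result = data_arr.reshape(-1)[flat_ind]

print(np.array_equal(ogrid_result, flat_result))
```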

The choose option from the other question (limited to choosing between 32 options) can be implemented like so:

start_time = time.time()
demos_arr = bayer_arr.choose((img_arr[...,0], img_arr[...,1], img_arr[...,2]))
end_time = time.time()

The ogrid solution runs in 12 ms and the choose solution in 34 ms on my machine.
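A rough way to reproduce this comparison is sketched below. The helper time_call, the repeat count, and the random test image are my own assumptions rather than part of the MWE, and absolute timings will differ by machine and NumPy version.

```python
import time

import numpy as np


def time_call(func, repeats=5):
    """Return the best wall-clock time of func() over a few runs, in seconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func()
        best = min(best, time.perf_counter() - start)
    return best


img_height, img_width = 1080, 1920
rng = np.random.default_rng(0)
# Random stand-in image so that a wrong channel choice would be visible.
img_arr = rng.integers(0, 1024, size=(img_height, img_width, 3), dtype=np.uint16)

bayer_arr = np.ones((img_height, img_width), dtype=np.uint16)
bayer_arr[0::2, 0::2] = 0  # red
bayer_arr[1::2, 1::2] = 2  # blue

row_ind, col_ind = np.ogrid[0:img_height, 0:img_width]


def with_ogrid():
    return img_arr[row_ind, col_ind, bayer_arr]


def with_choose():
    return bayer_arr.choose((img_arr[..., 0], img_arr[..., 1], img_arr[..., 2]))


# Both approaches should select exactly the same pixels.
assert np.array_equal(with_ogrid(), with_choose())

print("ogrid : {:.1f} ms".format(time_call(with_ogrid) * 1e3))
print("choose: {:.1f} ms".format(time_call(with_choose) * 1e3))
```

Taking the best of several runs reduces the influence of one-off effects such as cache warm-up, which a single time.time() measurement can pick up.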


Solution

  • You want numpy.take_along_axis:

    output_arr = numpy.take_along_axis(data_arr, index_arr[:, :, numpy.newaxis], axis=2)
    output_arr = output_arr[:,:,0]  # Since take_along_axis keeps the same number of dimensions
    

    This function was added in NumPy 1.15.0.

    https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/numpy.take_along_axis.html

    Note that data_arr and index_arr need to have the same number of dimensions, so you need to reshape index_arr to 3 dimensions and then reshape the result back to 2 dimensions. I.e.:

    start_time = time.time()
    demos_arr = np.take_along_axis(img_arr, bayer_arr.reshape([img_height, img_width, 1]), axis=2).reshape([img_height, img_width])
    end_time = time.time()
    

    The timing results for take_along_axis are the same as for the ogrid implementation.
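As a small worked example of the shape handling described in this answer, here is a sketch on a toy array (the toy data and the names expanded and picked are illustrative, not from the question):

```python
import numpy as np

data_arr = np.arange(24).reshape(2, 3, 4)
index_arr = np.array([[0, 1, 2],
                      [3, 0, 1]])

# Give index_arr a trailing axis so it has as many dimensions as data_arr.
expanded = index_arr[..., np.newaxis]

# One element is taken along axis 2 for every (i, j) position.
picked = np.take_along_axis(data_arr, expanded, axis=2)
print(picked.shape)  # (2, 3, 1)

# Drop the length-1 trailing axis to get the 2D result.
output_arr = picked[..., 0]
print(output_arr)
```

Indexing with [..., np.newaxis] and [..., 0] avoids spelling out the image dimensions, which the reshape calls above need.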