
Obtaining values from a 3D matrix based on locations in a 2D matrix


Suppose I have a NumPy array, data = np.random.rand(200, 50, 100), and the locations I need values from, locs = np.random.randint(50, size=(200, 2)). Each row of locs holds two indices into axis 1 of data for the corresponding row of data.

How would I obtain a resulting matrix of shape (200, 2, 100)? Essentially, I would like to obtain the values from data at the locations specified by locs.

If I do: data[locs], I end up with a resulting matrix of shape (200, 2, 50, 100) and not (200, 2, 100).
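A minimal reproduction of the shape problem (the data is random, so only the shapes matter here):

```python
import numpy as np

data = np.random.rand(200, 50, 100)
locs = np.random.randint(50, size=(200, 2))

# A single integer-array index applies to axis 0 only; axes 1 and 2
# (sizes 50 and 100) are carried along unchanged.
wrong = data[locs]
print(wrong.shape)  # (200, 2, 50, 100)
```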

Updated with more details as requested:

If we have:

data = np.arange(125)
reshaped = np.reshape(data, (5, 5, 5))
locs = [[3, 4], [2, 1], [1, 3], [3, 3], [0, 0]]

Then doing something like reshaped[locs] should give the following output (row i of the result is reshaped[i, locs[i]]):

array([[[ 15,  16,  17,  18,  19],
        [ 20,  21,  22,  23,  24]],

       [[ 35,  36,  37,  38,  39],
        [ 30,  31,  32,  33,  34]],

       [[ 55,  56,  57,  58,  59],
        [ 65,  66,  67,  68,  69]],

       [[ 90,  91,  92,  93,  94],
        [ 90,  91,  92,  93,  94]],

       [[100, 101, 102, 103, 104],
        [100, 101, 102, 103, 104]]])
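To pin down the requirement, here is a sketch of the desired result written as a plain Python loop (slow, but unambiguous). The name expected is only illustrative, and locs is converted to an array so it can be indexed as locs[i, k]:

```python
import numpy as np

reshaped = np.arange(125).reshape(5, 5, 5)
locs = np.array([[3, 4], [2, 1], [1, 3], [3, 3], [0, 0]])

# Reference definition: expected[i, k] = reshaped[i, locs[i, k], :]
n_rows, n_locs = locs.shape
expected = np.empty((n_rows, n_locs, reshaped.shape[2]), dtype=reshaped.dtype)
for i in range(n_rows):
    for k in range(n_locs):
        expected[i, k] = reshaped[i, locs[i, k]]

print(expected[1])
# [[35 36 37 38 39]
#  [30 31 32 33 34]]
```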

Solution

  • With advanced indexing, the indexed dimensions are replaced by the shape of the index arrays. data[locs] is equivalent to data[locs, :, :]: locs indexes axis 0 only, so your shape will be locs.shape + data.shape[1:], or (200, 2, 50, 100).

    What you appear to be asking for is to index axis 1 of data using locs, keeping axis 0 in lockstep with the row in locs. To do this, you need to index with locs along axis 1, and supply an index running from 0 to 199 (np.arange(200)) along axis 0.

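    It is important to remember that all the advanced indices must broadcast to the same shape. Since locs is shaped (200, 2), the first index must be shaped (200, 1) or (200, 2) to broadcast properly; a flat (200,) index would be aligned with the last axis of locs and fail. The broadcast shape (200, 2) then forms the leading part of the result, followed by the untouched axis 2, giving (200, 2, 100). I will show the (200, 1) version, since it is simpler and more efficient.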

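    import numpy as np

    data = np.random.rand(200, 50, 100)
    locs = np.random.randint(50, size=(200, 2))
    # Axis-0 index shaped (200, 1): broadcasts against locs (200, 2)
    rows = np.arange(200).reshape(-1, 1)

    # Pairs rows[i] with each entry of locs[i]; axis 2 is kept whole
    result = data[rows, locs, :]  # shape (200, 2, 100)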
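As a sanity check, here is a minimal sketch applying the same rows/locs indexing to the (5, 5, 5) example from the question. np.broadcast_shapes (NumPy 1.20+) is used only to display the broadcast index shape. The last lines show np.take_along_axis as an alternative that is not part of the answer above; it requires the indices to have the same number of dimensions as the array, hence locs[:, :, np.newaxis], whose length-1 last axis broadcasts over axis 2.

```python
import numpy as np

reshaped = np.arange(125).reshape(5, 5, 5)
locs = np.array([[3, 4], [2, 1], [1, 3], [3, 3], [0, 0]])

# rows has shape (5, 1); it broadcasts against locs (5, 2) so that
# row i of locs is paired with index i on axis 0.
rows = np.arange(locs.shape[0]).reshape(-1, 1)
print(np.broadcast_shapes(rows.shape, locs.shape))  # (5, 2)

result = reshaped[rows, locs, :]
print(result.shape)  # (5, 2, 5)

# Alternative (not from the answer): take_along_axis on axis 1 with
# 3-D indices of shape (5, 2, 1).
alt = np.take_along_axis(reshaped, locs[:, :, np.newaxis], axis=1)
print(np.array_equal(result, alt))  # True
```

Both versions produce the array listed in the question; with the full-size data and rows = np.arange(200).reshape(-1, 1), the same indexing gives shape (200, 2, 100).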