Tags: python, numpy, numpy-ndarray, numpy-slicing

Indexing a batch of lists of vectors with a batch of indices, where there are an arbitrary number of batch dimensions


This question generalises a previously asked question: Indexing a list of list of vectors with a list of indices

Given some data represented as an array of dimensions (N1,..., Nk, L, H), and a batch of indices of dimensions (N1,..., Nk), how can I index the data such that the output is of dimensions (N1,..., Nk, H)? Semantically, I'd like to replace the for loops below with a single NumPy call:

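import numpy as np

# N1 and N2 are two example batch shapes: three batch dimensions, and none at all
N1, N2 = (3, 3, 3), ()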
L, H = 5, 2

data = np.arange(np.prod(N1) * L * H).reshape(*N1, L, H)
inds = np.arange(np.prod(N1), dtype=int).reshape(N1) % L
out = np.empty((*N1, H), dtype=data.dtype)
for ii in np.ndindex(N1):
    out[ii] = data[ii + (inds[ii],)]
assert out.shape == (3, 3, 3, 2)

data = np.arange(np.prod(N2) * L * H).reshape(*N2, L, H)
inds = np.arange(np.prod(N2), dtype=int).reshape(N2) % L
out = np.empty((*N2, H), dtype=data.dtype)
for ii in np.ndindex(N2):
    out[ii] = data[ii + (inds[ii],)]
assert out.shape == (2,)

It seems like flattening the batch dimensions could work, but are there utilities in NumPy capable of indexing like this directly?


Solution

  • You can extend the answer in the linked post by using np.arange(N_i) as the index for each batch dimension i, shaped so that the index arrays broadcast together to form an (N1, ..., Nk)-shaped array. NumPy provides various index tricks to help with this. While there are multiple options, I think np.ogrid might work best:

    import numpy as np
    
    N1, N2, N3, L, H = 5, 3, 4, 10, 6
    
    data = np.random.rand(N1, N2, N3, L, H)
    inds = np.random.randint(L, size=(N1, N2, N3))
    
    # We need an index tuple with 4 arrays that broadcast to (N1, N2, N3):
    idx = tuple(np.ogrid[:N1, :N2, :N3]) + (inds,)
    
    out = data[idx]
    

    out.shape:

    (5, 3, 4, 6) # Which is (N1, N2, N3, H), as required
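
To see why this index tuple works, it can help to look at the shapes involved. The sketch below reuses the sizes from the example above (the names grid, grid_shapes and index_shape are only illustrative). Each open-grid array spans one batch axis and has length 1 on the others, so together with inds they broadcast to (N1, N2, N3). It assumes a NumPy version that provides np.broadcast_shapes (1.20 or later).

```python
import numpy as np

N1, N2, N3, L = 5, 3, 4, 10
inds = np.random.randint(L, size=(N1, N2, N3))

# One open-grid array per batch axis, each of length 1 along the other axes
grid = np.ogrid[:N1, :N2, :N3]
grid_shapes = [g.shape for g in grid]
print(grid_shapes)  # [(5, 1, 1), (1, 3, 1), (1, 1, 4)]

# Advanced indexing broadcasts all index arrays to one common shape
index_shape = np.broadcast_shapes(*grid_shapes, inds.shape)
print(index_shape)  # (5, 3, 4)
```

Because the four index arrays cover the first four axes of data, the result takes the broadcast shape (N1, N2, N3) followed by the remaining axis H.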
    

    There is also np.take_along_axis, though I find it a bit more awkward/less direct, because you need to explicitly align and then remove dimensions:

    # Squeeze only the indexed axis, so batch axes (or H) of size 1 are preserved
    out = np.take_along_axis(data, inds[..., None, None], axis=3).squeeze(axis=3)
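
Both approaches above hard-code three batch dimensions. As a rough sketch of how they might be generalised to an arbitrary number of batch dimensions (including none), the hypothetical helpers below derive everything from inds.shape. The function names are made up for illustration. np.indices(..., sparse=True) is swapped in for np.ogrid here because it always returns a tuple, even for zero or one batch axes; this assumes NumPy 1.17 or later. Both helpers are checked against the explicit loop from the question.

```python
import numpy as np


def batched_select(data, inds):
    """Pick data[..., inds[...], :] for data of shape (*batch, L, H).

    Hypothetical helper; inds must have shape `batch`.
    """
    inds = np.asarray(inds)
    # One sparse index array per batch axis; an empty tuple if there are none
    batch_idx = np.indices(inds.shape, sparse=True)
    return data[batch_idx + (inds,)]


def batched_select_take(data, inds):
    """Same result via np.take_along_axis, addressing the L axis as -2."""
    inds = np.asarray(inds)
    picked = np.take_along_axis(data, inds[..., None, None], axis=-2)
    # Remove only the indexed axis, which now has length 1
    return picked.squeeze(axis=-2)


def batched_select_loop(data, inds):
    """Reference implementation: the explicit loop from the question."""
    inds = np.asarray(inds)
    out = np.empty(inds.shape + data.shape[-1:], dtype=data.dtype)
    for ii in np.ndindex(inds.shape):
        out[ii] = data[ii + (inds[ii],)]
    return out


L, H = 5, 2
for batch in [(3, 3, 3), (4,), ()]:
    size = int(np.prod(batch, dtype=int))
    data = np.arange(size * L * H).reshape(*batch, L, H)
    inds = np.arange(size).reshape(batch) % L
    expected = batched_select_loop(data, inds)
    assert expected.shape == batch + (H,)
    assert np.array_equal(batched_select(data, inds), expected)
    assert np.array_equal(batched_select_take(data, inds), expected)
```

Using negative axes in the np.take_along_axis variant avoids having to count the batch dimensions at all, which is the main reason for preferring axis=-2 over a fixed positive axis in this sketch.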