This question generalises a previously asked question: Indexing a list of list of vectors with a list of indices
Given some data represented as an array of shape (N1, ..., Nk, L, H), and a batch of indices of shape (N1, ..., Nk), how can I index the data such that the output has shape (N1, ..., Nk, H)? Semantically, I'd like to replace the for loops below with a single NumPy call:
import numpy as np
N1, N2 = (3, 3, 3), ()  # two batch shapes: three batch dimensions, and none at all
L, H = 5, 2
data = np.arange(np.prod(N1) * L * H).reshape(*N1, L, H)
inds = np.arange(np.prod(N1), dtype=int).reshape(N1) % L
out = np.empty((*N1, H), dtype=data.dtype)
for ii in np.ndindex(N1):
    out[ii] = data[ii + (inds[ii],)]
assert out.shape == (3, 3, 3, 2)
data = np.arange(np.prod(N2) * L * H).reshape(*N2, L, H)
inds = np.arange(np.prod(N2), dtype=int).reshape(N2) % L
out = np.empty((*N2, H), dtype=data.dtype)
for ii in np.ndindex(N2):
    out[ii] = data[ii + (inds[ii],)]
assert out.shape == (2,)
It seems like flattening the batch dimensions could work, but are there utilities in NumPy capable of indexing like this directly?
You can extend the answer in the linked post by using np.arange(N_i) as the index for each batch dimension i, shaped so that all the index arrays broadcast together to an (N1, ..., Nk)-shaped array. NumPy provides various index tricks to help with this. While there are multiple options, I think np.ogrid might work best:
import numpy as np
N1, N2, N3, L, H = 5, 3, 4, 10, 6
data = np.random.rand(N1, N2, N3, L, H)
inds = np.random.randint(L, size=(N1, N2, N3))
# We need an index tuple with 4 arrays that broadcast to (N1, N2, N3):
idx = tuple(np.ogrid[:N1, :N2, :N3]) + (inds,)
out = data[idx]
out.shape  # (5, 3, 4, 6), which is (N1, N2, N3, H), as required
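Since the question asks about an arbitrary number of batch dimensions (including none), here is a minimal sketch of the same idea without hard-coding N1, N2, N3. It swaps np.ogrid for np.indices(..., sparse=True), which builds the same kind of broadcastable index arrays from a shape tuple (the sparse argument needs NumPy 1.17 or later). The helper name batched_select is made up for illustration, and the loop from the question is used only as a reference check.

```python
import numpy as np

def batched_select(data, inds):
    """Hypothetical helper: out[b] = data[b, inds[b], :] for every batch index b."""
    inds = np.asarray(inds)
    # One index array per batch axis, each shaped to broadcast to inds.shape.
    # For an empty batch shape this is simply an empty tuple.
    grids = np.indices(inds.shape, sparse=True)
    return data[grids + (inds,)]

rng = np.random.default_rng(0)
L, H = 5, 2
for batch_shape in [(3, 3, 3), (4,), ()]:
    data = rng.random((*batch_shape, L, H))
    inds = np.asarray(rng.integers(L, size=batch_shape))

    # Reference result using the explicit loop from the question.
    expected = np.empty((*batch_shape, H), dtype=data.dtype)
    for ii in np.ndindex(batch_shape):
        expected[ii] = data[ii + (inds[ii],)]

    out = batched_select(data, inds)
    assert out.shape == (*batch_shape, H)
    assert np.array_equal(out, expected)
```

The index tuple has one entry per batch axis plus inds itself, so the trailing H axis is left untouched and appears last in the result.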
There is also np.take_along_axis, though I find it a bit more awkward/less direct, because you need to explicitly align and then remove dimensions:
out = np.take_along_axis(data, inds[..., None, None], axis=3).squeeze(axis=3)
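One caveat with the np.take_along_axis route: calling .squeeze() with no axis removes every length-1 axis, so a batch dimension (or H) of size 1 would silently disappear as well. The sketch below, with a made-up helper name batched_select_taa, addresses the L axis as axis=-2 so it works for any number of batch dimensions, and squeezes only that axis.

```python
import numpy as np

def batched_select_taa(data, inds):
    """Hypothetical helper: same selection, via np.take_along_axis."""
    inds = np.asarray(inds)
    # Add two trailing axes so inds lines up with the (L, H) axes of data;
    # take_along_axis broadcasts the index over the H axis.
    picked = np.take_along_axis(data, inds[..., None, None], axis=-2)
    # Remove only the (now length-1) L axis.
    return picked.squeeze(axis=-2)

# Batch shape (1, 3), L = 5, H = 2: note the length-1 batch axis.
data = np.arange(1 * 3 * 5 * 2).reshape(1, 3, 5, 2)
inds = np.array([[0, 2, 4]])

out = batched_select_taa(data, inds)
print(out.shape)   # (1, 3, 2)

# A bare squeeze() also drops the length-1 batch axis:
bare = np.take_along_axis(data, inds[..., None, None], axis=-2).squeeze()
print(bare.shape)  # (3, 2)
```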