Given x with shape [4, 5, 3] and i with shape [4, 3], referring to indices along dimension 1 (of length 5) in x, I want to create y from x, with shape [4, 3], such that y[j, k] == x[j, i[j, k], k].
I think the correct answer is as follows:
y = x[np.arange(4).reshape(4, 1), i, np.arange(3).reshape(1, 3)]
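Why this works, as far as I understand it: the three index arrays have shapes (4, 1), (4, 3) and (1, 3), so they broadcast to a common shape of (4, 3), and NumPy then picks one element of x per position in that broadcast grid. Below is a minimal sketch that checks the one-liner against an explicit double loop. The names rows, cols and y_loop are just illustrative, and the random data is an assumption, not the data from the example.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(100, size=(4, 5, 3))   # arbitrary test data
i = rng.integers(5, size=(4, 3))        # indices along dimension 1 of x

# rows[j, 0] == j and cols[0, k] == k; together with i they broadcast to (4, 3)
rows = np.arange(4).reshape(4, 1)
cols = np.arange(3).reshape(1, 3)
y = x[rows, i, cols]

# Reference implementation with explicit loops over j and k
y_loop = np.empty((4, 3), dtype=x.dtype)
for j in range(4):
    for k in range(3):
        y_loop[j, k] = x[j, i[j, k], k]

assert np.array_equal(y, y_loop)
```

The loop version is only there as a readable specification of y[j, k] == x[j, i[j, k], k]; the indexing expression is the one to use in practice.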
Example:
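import numpy as np
rng = np.random.default_rng(0)
# x holds the values 0..59 in shuffled order, so every entry is unique
x = np.arange(4 * 5 * 3)
rng.shuffle(x)
x = x.reshape(4, 5, 3)
# i holds, for each (j, k), an index along dimension 1 of x
i = rng.integers(5, size=[4, 3])
# the index arrays of shapes (4, 1), (4, 3), (1, 3) broadcast to (4, 3)
y = x[np.arange(4).reshape(4, 1), i, np.arange(3).reshape(1, 3)]
print("x:", x, "i:", i, "y:", y, sep="\n")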
Output:
x:
[[[16 27 20]
  [ 8 42 34]
  [51  4 52]
  [57 10  2]
  [44 23 24]]

 [[43 11 35]
  [30 18 54]
  [ 3  1 55]
  [17 21 36]
  [ 0 28  6]]

 [[19 48 22]
  [26 37 46]
  [58 32 25]
  [53  9 38]
  [47 50 40]]

 [[13 12  7]
  [45 39 59]
  [ 5 49 14]
  [29 41 56]
  [33 15 31]]]
i:
[[1 3 4]
 [0 0 3]
 [1 2 0]
 [4 2 4]]
y:
[[ 8 10 24]
 [43 11 36]
 [26 32 22]
 [33 49 31]]
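An alternative that should give the same result is np.take_along_axis, which exists for exactly this kind of "pick along one axis" lookup. It requires the index array to have the same number of dimensions as x, so in this sketch i gets a length-1 axis inserted at position 1, which is dropped again afterwards. The names y_fancy and y_take are illustrative, and the random data is again an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(100, size=(4, 5, 3))   # arbitrary test data
i = rng.integers(5, size=(4, 3))        # indices along dimension 1 of x

# The broadcasting-based indexing from above
y_fancy = x[np.arange(4).reshape(4, 1), i, np.arange(3).reshape(1, 3)]

# i[:, np.newaxis, :] has shape (4, 1, 3); the result has shape (4, 1, 3) too,
# so the length-1 axis is removed again with [:, 0, :]
y_take = np.take_along_axis(x, i[:, np.newaxis, :], axis=1)[:, 0, :]

assert np.array_equal(y_fancy, y_take)
```

Which of the two reads better is probably a matter of taste; the take_along_axis version avoids hard-coding the sizes 4 and 3.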
(Rubber-duck debugging FTW)