Let's say I have an array of permutations perm, which could look like:
perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
If I want to apply it to one axis, I can write something like:
v = np.arange(9).reshape(3, 3)
print(repr(v[perm]))
Output:
array([[[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]],

       [[3, 4, 5],
        [6, 7, 8],
        [0, 1, 2]],

       [[0, 1, 2],
        [6, 7, 8],
        [3, 4, 5]],

       [[6, 7, 8],
        [3, 4, 5],
        [0, 1, 2]]])
Now I would like to apply it to two axes at the same time. I figured out that I can do it via:
np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])
But I find it quite inefficient, because it has to create a mesh grid for every permutation, and it also requires a for loop. I made a small array in this example, but in reality I have much larger arrays with many permutations, so I would really love to have something that's as quick and simple as the one-axis version.
How about:
p1 = perm[:, :, np.newaxis]
p2 = perm[:, np.newaxis, :]
v[p1, p2]
The zeroth axis of p1 and p2 is just the "batch" dimension of perm, which allows you to do many permutations in one operation.
The other dimension of perm, which corresponds to the indices, is aligned along the first axis in p1 and the second in p2. Because the axes are orthogonal, the arrays get broadcast, basically like the arrays you got using meshgrid, but these still have the batch dimension.
That's the best I can do from my cell phone : ) I can try to clarify later if needed, but the key idea is broadcasting.
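To make the shapes concrete, here is a minimal sketch using the same perm and v as in the question. It prints the shapes of the two index arrays and of the result, then checks element by element what the indexing computes. The explicit loops are only there for illustration and are not part of the solution.

```python
import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

p1 = perm[:, :, np.newaxis]  # shape (4, 3, 1): row indices per permutation
p2 = perm[:, np.newaxis, :]  # shape (4, 1, 3): column indices per permutation
res = v[p1, p2]              # index arrays broadcast to shape (4, 3, 3)

print(p1.shape, p2.shape, res.shape)
# (4, 3, 1) (4, 1, 3) (4, 3, 3)

# Element-wise meaning: res[b, i, j] == v[perm[b, i], perm[b, j]]
n_perms, n = perm.shape
for b in range(n_perms):
    for i in range(n):
        for j in range(n):
            assert res[b, i, j] == v[perm[b, i], perm[b, j]]
```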
Comparison:
import numpy as np
perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)
ref = np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])
p1 = perm[:, :, np.newaxis]
p2 = perm[:, np.newaxis, :]
res = v[p1, p2]
np.testing.assert_equal(res, ref)
# passes
%timeit np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])
# 107 µs ± 20.6 µs per loop
%timeit v[perm[:, :, np.newaxis], perm[:, np.newaxis, :]]
# 3.73 µs ± 1.07 µs per loop
A simpler (without batch dimension) example of broadcasting indices:
import numpy as np
i = np.arange(3)
ref = np.meshgrid(i, i, indexing="ij")
res = np.broadcast_arrays(i[:, np.newaxis], i[np.newaxis, :])
np.testing.assert_equal(res, ref)
# passes
In the solution code at the top, the broadcasting is implicit. We don't need to call broadcast_arrays because it happens automatically during the indexing.
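For a single permutation (no batch dimension), NumPy's np.ix_ builds the same kind of broadcastable index arrays for you. The sketch below, with an example permutation p chosen only for illustration, checks that it matches the manual newaxis version. As far as I know np.ix_ only accepts 1-D sequences, so for the batched case the newaxis approach above is still the one to use.

```python
import numpy as np

v = np.arange(9).reshape(3, 3)
p = np.array([1, 2, 0])  # one example permutation, chosen for illustration

# np.ix_ returns index arrays of shape (3, 1) and (1, 3)
a = v[np.ix_(p, p)]

# The same thing written out with explicit new axes
b = v[p[:, np.newaxis], p[np.newaxis, :]]

np.testing.assert_equal(a, b)
print(a)
# [[4 5 3]
#  [7 8 6]
#  [1 2 0]]
```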