
Apply permutation array on multiple axes in numpy


Let's say I have an array of permutations perm which could look like:

import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])

If I want to apply it to one axis, I can write something like:

v = np.arange(9).reshape(3, 3)
v[perm]

Output:

array([[[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]],

       [[3, 4, 5],
        [6, 7, 8],
        [0, 1, 2]],

       [[0, 1, 2],
        [6, 7, 8],
        [3, 4, 5]],

       [[6, 7, 8],
        [3, 4, 5],
        [0, 1, 2]]])
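As a quick check of what the one-axis version does: indexing with the (4, 3) array `perm` replaces axis 0 of `v`, so the result has shape (4, 3, 3) and slice `k` is just `v[perm[k]]`. A minimal sketch; `np.take` is shown only as an equivalent spelling and is not part of the original question.

```python
import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

out = v[perm]  # fancy indexing on axis 0 with a 2-D index array
print(out.shape)  # (4, 3, 3): perm's shape (4, 3) followed by v's remaining axis (3,)

# Slice k is v with its rows reordered by perm[k]
for k, p in enumerate(perm):
    assert np.array_equal(out[k], v[p])

# np.take along axis 0 gives the same result
assert np.array_equal(out, np.take(v, perm, axis=0))
```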

Now I would like to apply it to two axes at the same time. I figured out that I can do it via:

np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])

But I find this inefficient: it builds a mesh grid for every permutation and needs a Python for loop. The array in this example is small, but in reality I have much larger arrays and many permutations, so I would really like something as fast and simple as the one-axis version.
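For reference, this is what the loop body computes for a single permutation `p`: `meshgrid(p, p, indexing="ij")` yields index arrays with `I[i, j] == p[i]` and `J[i, j] == p[j]`, so both rows and columns of `v` are permuted. A minimal sketch; `np.ix_` is a swapped-in alternative (not from the question) that builds an "open" mesh of shapes (3, 1) and (1, 3) and relies on broadcasting instead of full grids, though it still handles one permutation at a time.

```python
import numpy as np

v = np.arange(9).reshape(3, 3)
p = np.array([1, 2, 0])  # one row of perm

# What the loop body builds for one permutation
I, J = np.meshgrid(p, p, indexing="ij")  # I[i, j] == p[i], J[i, j] == p[j]
permuted = v[I, J]  # permuted[i, j] == v[p[i], p[j]]

# Equivalent: permute the rows, then the columns
assert np.array_equal(permuted, v[p][:, p])

# Equivalent: np.ix_ returns index arrays of shape (3, 1) and (1, 3) that broadcast
rows, cols = np.ix_(p, p)
print(rows.shape, cols.shape)  # (3, 1) (1, 3)
assert np.array_equal(permuted, v[rows, cols])
```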


Solution

    How about:

    p1 = perm[:, :, np.newaxis]
    p2 = perm[:, np.newaxis, :]
    v[p1, p2]
    

    The zeroth axis of p1 and p2 is just the "batch" dimension of perm, which allows you to do many permutations in one operation.

    The other dimension of perm, which holds the indices, lies along axis 1 in p1 and along axis 2 in p2. Since p1 has shape (4, 3, 1) and p2 has shape (4, 1, 3), the two index arrays broadcast against each other to shape (4, 3, 3). That is essentially the pair of arrays you got from meshgrid, except that they keep the batch dimension, so res[k, i, j] == v[perm[k, i], perm[k, j]].
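The shape bookkeeping described above can be checked directly. A small sketch reusing the answer's names; `np.broadcast_shapes` (available since NumPy 1.20) is used only to print the common shape, and the triple loop is a brute-force check of the element-wise meaning, not something you would run on large arrays.

```python
import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

p1 = perm[:, :, np.newaxis]  # shape (4, 3, 1): indices run along axis 1
p2 = perm[:, np.newaxis, :]  # shape (4, 1, 3): indices run along axis 2
print(p1.shape, p2.shape)  # (4, 3, 1) (4, 1, 3)

# The two index arrays broadcast to a common (4, 3, 3) shape
print(np.broadcast_shapes(p1.shape, p2.shape))  # (4, 3, 3)

res = v[p1, p2]

# Brute-force check of the element-wise meaning
K, N = perm.shape
for k in range(K):
    for i in range(N):
        for j in range(N):
            assert res[k, i, j] == v[perm[k, i], perm[k, j]]
```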

    That's the best I can do from my cell phone : ) I can try to clarify later if needed, but the key idea is broadcasting.
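The same broadcasting idea should extend to more than two axes. A minimal sketch, assuming you want to apply each row of `perm` to the first `n_axes` axes of `v`, each of length `m`; the helper name `permute_axes` is hypothetical. It reshapes `perm` once per axis so that the index dimension lands on a different axis each time, then lets a single indexing call broadcast them.

```python
import numpy as np


def permute_axes(v, perm, n_axes):
    """Hypothetical helper: apply every row of `perm` to the first `n_axes` axes of `v`.

    perm has shape (K, m); each of the first n_axes axes of v must have length m.
    Returns an array of shape (K, m, ..., m, *v.shape[n_axes:]).
    """
    K, m = perm.shape
    index = []
    for axis in range(n_axes):
        # Shape (K, 1, ..., m, ..., 1) with m at position 1 + axis
        shape = [K] + [1] * n_axes
        shape[1 + axis] = m
        index.append(perm.reshape(shape))
    # The index arrays broadcast to (K, m, ..., m); remaining axes of v are kept
    return v[tuple(index)]


perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])

# Two axes: matches the answer's v[p1, p2]
v2 = np.arange(9).reshape(3, 3)
assert np.array_equal(permute_axes(v2, perm, 2), v2[perm[:, :, None], perm[:, None, :]])

# Three axes: slice k equals v3 with all three axes permuted by perm[k]
v3 = np.arange(27).reshape(3, 3, 3)
out3 = permute_axes(v3, perm, 3)
print(out3.shape)  # (4, 3, 3, 3)
p = perm[1]
assert np.array_equal(out3[1], v3[np.ix_(p, p, p)])
```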

    Comparison:

    import numpy as np
    perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
    v = np.arange(9).reshape(3, 3)
    
    ref = np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])
    
    p1 = perm[:, :, np.newaxis]
    p2 = perm[:, np.newaxis, :]
    res = v[p1, p2]
    
    np.testing.assert_equal(res, ref)
    # passes
    
    %timeit np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])
    # 107 µs ± 20.6 µs per loop
    
    %timeit v[perm[:, :, np.newaxis], perm[:, np.newaxis, :]]
    # 3.73 µs ± 1.07 µs per loop
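`%timeit` is an IPython/Jupyter magic. In a plain Python script, a similar comparison could be run with the standard `timeit` module. A sketch only: the function names are illustrative, no timings are claimed here, and the numbers you get will depend on your machine, NumPy version, and especially the array and permutation sizes (the 3×3 example says little about the large case in the question).

```python
import timeit

import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)


def loop_meshgrid():
    # Question's approach: one meshgrid and one indexing call per permutation
    return np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])


def broadcast_index():
    # Answer's approach: a single indexing call with broadcast index arrays
    return v[perm[:, :, np.newaxis], perm[:, np.newaxis, :]]


n = 1000
for f in (loop_meshgrid, broadcast_index):
    # timeit.repeat returns the total seconds for each of `repeat` runs of `number` calls
    best = min(timeit.repeat(f, number=n, repeat=5))
    print(f"{f.__name__}: {best / n * 1e6:.2f} µs per call")
```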
    

    A simpler example of broadcasting indices, without the batch dimension:

    import numpy as np
    i = np.arange(3)
    ref = np.meshgrid(i, i, indexing="ij")
    res = np.broadcast_arrays(i[:, np.newaxis], i[np.newaxis, :])
    np.testing.assert_equal(res, ref)
    # passes
    

    In the solution code at the top, the broadcasting is implicit. We don't need to call broadcast_arrays because it happens automatically during the indexing.
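To confirm that the implicit broadcasting reproduces the meshgrid arrays with the batch dimension included, the index arrays from the solution can be broadcast explicitly and compared slice by slice. A minimal sketch reusing the answer's names; in practice the explicit `broadcast_arrays` call is unnecessary.

```python
import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

p1 = perm[:, :, np.newaxis]
p2 = perm[:, np.newaxis, :]

# Explicit broadcasting: both index arrays become full (4, 3, 3) arrays
b1, b2 = np.broadcast_arrays(p1, p2)
print(b1.shape, b2.shape)  # (4, 3, 3) (4, 3, 3)

# Each batch slice matches the meshgrid the loop builds for that permutation
for k, p in enumerate(perm):
    I, J = np.meshgrid(p, p, indexing="ij")
    assert np.array_equal(b1[k], I) and np.array_equal(b2[k], J)

# Indexing with the explicit or the implicit version gives the same result
assert np.array_equal(v[b1, b2], v[p1, p2])
```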