
Einops rearrange function basic functionality


I'm trying to grok the einops syntax for tensor reordering, but I'm somehow missing the point.

If I have the following matrix:

mat = torch.randint(1, 10, (8,4))

I understand what the following command does:

rearrange(mat, '(h n) w -> (n h) w', n = 2)

But I can't really wrap my head around the following ones:

rearrange(mat, '(n h) w -> (h n) w', n = 2)
rearrange(mat, '(n h) w -> (h n) w', n = 4)

Any help would be appreciated.


Solution

  • rearrange(mat, '(h n) w -> (n h) w', n = 2)
    and
    rearrange(mat, '(n h) w -> (h n) w', n = 2)
    

    are inverses of each other: if you can picture what one does, the second one performs the reverse transform.

    As for the last one, recall that mat is 8x4:

    rearrange(mat, '(n h) w -> (h n) w', n = 4)
    

    So you first split the first dimension into 4x2 (below I ignore the w dimension, because nothing special happens to it):

    [0, 1, 2, 3, 4, 5, 6, 7]
    

    to

    [0, 1, 
     2, 3, 
     4, 5, 
     6, 7]
    

    then you swap the order of the axes, which gives 2x4 (a transpose):

    [0, 2, 4, 6,
     1, 3, 5, 7]
    

    then merge the two dimensions back into one:

    [0, 2, 4, 6, 1, 3, 5, 7]
    

    If you still don't get a feel for how that works, try simpler examples such as

    import numpy as np
    from einops import rearrange
    rearrange(np.arange(50), '(h n) -> h n', h=5)
    rearrange(np.arange(50), '(h n) -> h n', h=10)
    rearrange(np.arange(50), '(h n) -> n h', h=10)
    

    and so on, so that you can track the movement of each element in the matrix.