
`movedim()` vs `moveaxis()` vs `permute()` in PyTorch


I'm completely new to PyTorch, and I was wondering if there's anything I'm missing when it comes to the moveaxis() and movedim() methods. They produce exactly the same output for the same arguments. Also, can't both of these methods be replaced by permute()?

An example for reference:

import torch

mytensor = torch.randn(3, 6, 3, 1, 7, 21, 4)

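# Move dimension 2 to position 5 with each function
t_md = torch.movedim(mytensor, 2, 5)
t_ma = torch.moveaxis(mytensor, 2, 5)

print(t_md.shape, t_ma.shape)
print(torch.allclose(t_md, t_ma))

# The same move written as an explicit permutation of all seven dimensions
t_p = torch.permute(mytensor, (0, 1, 3, 4, 5, 2, 6))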

print(t_p.shape)
print(torch.allclose(t_md, t_p))

Solution

  • Yes, moveaxis is an alias of movedim (analogous to swapaxes and swapdims).[1]

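  • Yes, the same result can be achieved with permute, but permute requires you to spell out the new position of every dimension, whereas movedim only needs the source and destination. Moving one axis while keeping the relative order of all the others is a common enough use case to warrant its own syntactic sugar.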
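To make the trade-off concrete, here is a minimal sketch of what movedim saves you from writing by hand when moving a single dimension. The helper `movedim_via_permute` is hypothetical (it is not part of PyTorch): it builds the full permutation that `permute` needs and is checked against `torch.movedim` for every source/destination pair.

```python
import torch


def movedim_via_permute(t: torch.Tensor, source: int, destination: int) -> torch.Tensor:
    """Hypothetical helper: move one dimension using permute, like torch.movedim."""
    ndim = t.dim()
    # Turn negative indices into their non-negative equivalents.
    source %= ndim
    destination %= ndim
    # All other dimensions keep their relative order...
    order = [d for d in range(ndim) if d != source]
    # ...and the moved dimension is slotted in at its target position.
    order.insert(destination, source)
    return t.permute(tuple(order))


mytensor = torch.randn(3, 6, 3, 1, 7, 21, 4)

# The question's example: dim 2 -> position 5 yields the order (0, 1, 3, 4, 5, 2, 6).
print(movedim_via_permute(mytensor, 2, 5).shape)  # torch.Size([3, 6, 1, 7, 21, 3, 4])

# Compare the helper with torch.movedim for every (source, destination) pair.
for src in range(mytensor.dim()):
    for dst in range(mytensor.dim()):
        assert torch.equal(movedim_via_permute(mytensor, src, dst),
                           torch.movedim(mytensor, src, dst))
print("helper matches torch.movedim")
```

The bookkeeping in the helper (removing the moved index, reinserting it, handling negative indices) is exactly what movedim hides behind its two arguments.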


    [1] The terminology is taken from NumPy. The torch.moveaxis documentation reads:

      Alias for torch.movedim().

      This function is equivalent to NumPy’s moveaxis function.
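A quick check of the points in this footnote, assuming NumPy is installed alongside PyTorch: `torch.moveaxis` and `torch.movedim` give identical results, `torch.swapaxes` and `torch.swapdims` behave the same way, and the results match `np.moveaxis` for the same arguments, including the tuple form that moves several dimensions at once.

```python
import numpy as np
import torch

t = torch.randn(3, 6, 3, 1, 7, 21, 4)
a = t.numpy()  # shares memory with t (CPU tensor, no grad)

# moveaxis is an alias of movedim: same arguments, same result.
assert torch.equal(torch.moveaxis(t, 2, 5), torch.movedim(t, 2, 5))

# swapaxes is likewise an alias of swapdims (both swap exactly two dimensions).
assert torch.equal(torch.swapaxes(t, 0, 6), torch.swapdims(t, 0, 6))

# Both follow NumPy's moveaxis semantics, including tuples of source/destination dims.
for src, dst in [(2, 5), (-1, 0), ((0, 1), (-1, -2))]:
    moved = torch.movedim(t, src, dst)
    expected = np.moveaxis(a, src, dst)
    assert tuple(moved.shape) == expected.shape, (src, dst)
    assert np.array_equal(moved.numpy(), expected), (src, dst)
    print(src, dst, tuple(moved.shape))
```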