
Reshape a torch tensor by swapping column and row indexes


I have this tensor:

a = torch.tensor([[  101,   101,   101,   101],
        [14812, 16890,  2586,  2586],
        [10337,  1830,  3842,  3842],
        [ 7257, 14541,  3293,  3297]])

How can I reshape it into:

a = torch.tensor([[  101, 14812, 10337,  7257],
        [  101, 16890,  1830, 14541],
        [  101,  2586,  3842,  3293],
        [  101,  2586,  3842,  3297]])

Solution

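  • What you want is a transpose, not a reshape. Use the Tensor.T property, which returns the transpose of a 2-D tensor (rows become columns and columns become rows):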

    a.T
    

    Output:

    tensor([[  101, 14812, 10337,  7257],
            [  101, 16890,  1830, 14541],
            [  101,  2586,  3842,  3293],
            [  101,  2586,  3842,  3297]])
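For reference, here is a small sketch of equivalent ways to get the same result, assuming a reasonably recent PyTorch. The variable names (t1 to t5, t1_copy) are only illustrative. It also shows that the transpose is a view sharing memory with a, which is why it is not contiguous. Recent PyTorch versions warn when .T is used on tensors that are not 2-D, so transpose() or permute() is the safer choice there.

```python
import torch

a = torch.tensor([[  101,   101,   101,   101],
                  [14812, 16890,  2586,  2586],
                  [10337,  1830,  3842,  3842],
                  [ 7257, 14541,  3293,  3297]])

# Equivalent ways to swap the two dimensions of a 2-D tensor.
t1 = a.T                       # property form used in the answer
t2 = a.t()                     # method; only valid for tensors with <= 2 dims
t3 = a.transpose(0, 1)         # explicitly swap dimension 0 and dimension 1
t4 = torch.transpose(a, 0, 1)  # functional form of the same operation
t5 = a.permute(1, 0)           # general reordering of dimensions

# All of them produce the same values.
for t in (t2, t3, t4, t5):
    assert torch.equal(t1, t)

# The transpose is a view: it shares storage with `a` and only swaps strides,
# so it is not laid out in row-major (contiguous) order.
print(t1.data_ptr() == a.data_ptr())  # True
print(t1.is_contiguous())             # False

# If a later operation such as .view() needs contiguous memory, make a copy.
t1_copy = t1.contiguous()
print(t1_copy.is_contiguous())        # True
```

By contrast, reshape() only regroups elements in their existing row-major order, so on its own it cannot move a[1, 0] to position [0, 1]; that is why a transpose is the right tool here.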