
Shuffling along a given axis in PyTorch


I have a dataset that gets loaded with dimensions [batch_size, seq_len, n_features] (e.g. torch.Size([16, 600, 130])).

I want to shuffle this data along the sequence-length axis (axis=1) in PyTorch, without altering the batch ordering or the ordering within each feature vector.

To illustrate, suppose the batch size is 3, the sequence length is 3, and the number of features is 2.

For example, given tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]]), I want to randomly shuffle it into something like:

tensor([[[3,3],[1,1],[2,2]],[[6,6],[5,5],[4,4]],[[8,8],[7,7],[9,9]]])

Are there any PyTorch functions that will do that automatically for me, or does anyone know what would be a good way to implement this?


Solution

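  • You can use torch.randperm, which returns a random permutation of the integers 0 to n - 1, and use it to index the sequence dimension.

    For a tensor t:

    t[:, torch.randperm(t.shape[1]), :]

    This draws a single permutation and applies it to every sample in the batch, so all sequences are reordered in the same way.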
    

    For your example:

    >>> t = torch.tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
    >>> t
    tensor([[[1, 1],
             [2, 2],
             [3, 3]],
    
            [[4, 4],
             [5, 5],
             [6, 6]],
    
            [[7, 7],
             [8, 8],
             [9, 9]]])
    >>> t[:,torch.randperm(t.shape[1]),:]
    tensor([[[2, 2],
             [3, 3],
             [1, 1]],
    
            [[5, 5],
             [6, 6],
             [4, 4]],
    
            [[8, 8],
             [9, 9],
             [7, 7]]])