Tags: python, pytorch, torch, transformer-model, pytorch-dataloader

Split torch dataset without shuffling


I'm using PyTorch to run a Transformer model. When I want to split the data (tokenized data), I'm using this code:

train_dataset, test_dataset = torch.utils.data.random_split(
                                                            tokenized_datasets,
                                                            [train_size, test_size])

torch.utils.data.random_split shuffles the indices before splitting, but I don't want to shuffle; I want to split the data sequentially.

Any advice? Thanks.


Solution

  • The random_split method has no parameter that can help you create a non-random sequential split.

    The easiest way to achieve a sequential split is by directly passing the indices for the subset you want to create:

    # Created using indices from 0 to train_size.
    train_dataset = torch.utils.data.Subset(tokenized_datasets, range(train_size))
    
    # Created using indices from train_size to train_size + test_size.
    test_dataset = torch.utils.data.Subset(tokenized_datasets, range(train_size, train_size + test_size))
    

    Refer to the PyTorch docs for torch.utils.data.Subset. A fuller end-to-end sketch follows below.
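
Here is a minimal end-to-end sketch of the approach above, assuming `tokenized_datasets` is a map-style dataset (supports `len()` and integer indexing). The 90/10 split ratio, batch size, and `DataLoader` settings are illustrative choices, not part of the original answer:

    import torch
    from torch.utils.data import Subset, DataLoader

    # Illustrative 90/10 split computed from the dataset length.
    train_size = int(0.9 * len(tokenized_datasets))
    test_size = len(tokenized_datasets) - train_size

    # Sequential (non-shuffled) split: the first train_size samples go to
    # training, the remaining test_size samples go to testing.
    train_dataset = Subset(tokenized_datasets, range(train_size))
    test_dataset = Subset(tokenized_datasets, range(train_size, train_size + test_size))

    # The split itself stays sequential; you can still shuffle *within* the
    # training split at batch time if desired.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
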