
How can I return NumPy arrays instead of tensors from torch.utils.data.DataLoader?


torch.utils.data.DataLoader returns torch tensors. Is there a way to return NumPy arrays instead? The subclass below still returns tensors. I would like to change __iter__() so that it returns NumPy arrays (for example via my_tensor.numpy()), but I cannot figure out how.

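from torch.utils.data import DataLoader

class CustomDataLoader(DataLoader):
    def __init__(self, dataset):
        super().__init__(dataset)

    def __iter__(self):
        it_ = super().__iter__()
        # Debug output: note that next(it_) consumes the first batch.
        print(next(it_))
        print(super().__iter__().__dict__)
        return it_  # still yields torch tensors

# dataset is any torch.utils.data.Dataset
c = CustomDataLoader(dataset)
next(iter(c))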


Solution

  • Yes, you can define your own collate function and pass it as DataLoader(dataset, collate_fn=my_function). The collate function is responsible for aggregating, or "collating", the individual elements of a batch into indexable or iterable batches (e.g. turning a list of n tensors of size [100, 100] into a single tensor of size [n, 100, 100]). If you want to collate your data in non-trivial ways, or if your data contains unusual types, this is often the way to go, since PyTorch's default collate function only covers the most common cases. In the simplest case, your collate function could just convert any tensors to NumPy arrays with <tensor>.numpy() (use <tensor>.detach().cpu().numpy() if the tensor requires grad or lives on a GPU).

    You can check out the docs or this StackOverflow question about defining custom collate functions.

    Hope this helps!
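
    To make the collate_fn approach concrete, here is a minimal sketch. It reuses PyTorch's default_collate to do the stacking and then converts the result to NumPy. The names numpy_collate and _to_numpy, and the toy TensorDataset, are made up for illustration and do not come from the question; adapt them to your own dataset.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate


def _to_numpy(obj):
    """Recursively convert tensors inside a collated batch to NumPy arrays."""
    if isinstance(obj, torch.Tensor):
        # detach() and cpu() keep this safe for tensors that require grad
        # or live on a GPU; for plain CPU tensors they are no-ops.
        return obj.detach().cpu().numpy()
    if isinstance(obj, dict):
        return {key: _to_numpy(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Sequences are returned as plain lists (a simplification).
        return [_to_numpy(value) for value in obj]
    return obj


def numpy_collate(batch):
    """Hypothetical collate_fn: default stacking, then tensor -> ndarray."""
    return _to_numpy(default_collate(batch))


# Toy dataset (an assumption for the example): 6 samples, 2 features each.
dataset = TensorDataset(
    torch.arange(12, dtype=torch.float32).reshape(6, 2),
    torch.arange(6),
)
loader = DataLoader(dataset, batch_size=3, collate_fn=numpy_collate)

features, labels = next(iter(loader))
print(type(features), features.shape)
print(type(labels), labels.shape)
```

    With this in place there is no need to subclass DataLoader or override __iter__(); every batch the loader yields has already passed through numpy_collate. If your samples are already NumPy arrays, default_collate will first turn them into tensors, so in that case stacking directly with np.stack inside the collate function may be the simpler choice.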