
DataLoader returns multiple values sequentially instead of a list or tuple


def __init__(self): ...

def __len__(self): ...

def __getitem__(self, idx):    
    cat_cols = (self.cat_cols.values.astype(np.float32))
    cont_cols = (self.cont_cols.values.astype(np.float32))
    label = (self.label.astype(np.int32))
    return (cont_cols[idx], cat_cols[idx], label[idx])

When I use a DataLoader with the class above, I get cont_cols, cat_cols and label as separate outputs at iterations 0, 1 and 2, whereas I want them together. I have tried returning the values as a dictionary, but then I run into indexing issues.

I have to read the output of the dataloader like this:

dl = DataLoader(dataset[0], batch_size = 1)



for i, data in enumerate(dl):
    if i == 0:
       cont = data
    if i == 1:
       cat = data
    if i == 2:
       label = data

Currently my output for

for i, data in enumerate(dl):
   print(i, data) 

is

0 tensor([[3.2800e+02, 4.8000e+01, 1.0000e+03, 1.4069e+03, 4.6613e+05, 5.3300e+04, 0.0000e+00, 5.0000e+00, 1.0000e+00, 1.0000e+00, 2.0000e+00, 7.1610e+04, 6.5100e+03, 1.3020e+04, 5.2080e+04, 2.0040e+03]])

1 tensor([[ 2., 1., 1., 4., 2., 17., 0., 2., 3., 0., 4., 4., 1., 2., 2., 10., 1.]])

2 tensor([1], dtype=torch.int32)

What I want is for the output to be accessible as data[0], data[1] and data[2], but the dataloader gives me back only a single tensor on each iteration. It returns cont_cols first, then cat_cols, and then label.


Solution

  • I think the confusion is in how the dataset is passed to the DataLoader. Your dataset can indeed return tuples, but you have to handle it differently.

    Your dataset is defined as:

    class MyDataset(Dataset):
        def __init__(self):
            pass
    
        def __len__(self):
            pass
    
        def __getitem__(self, idx):    
            cat_cols = (self.cat_cols.values.astype(np.float32))
            cont_cols = (self.cont_cols.values.astype(np.float32))
            label = (self.label.astype(np.int32))
            return (cont_cols[idx], cat_cols[idx], label[idx])
    

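    Then you create your dataset and data loader. Note that you should pass the dataset itself, not dataset[0]. dataset[0] is a single sample, i.e. a 3-tuple, which the DataLoader treats as a dataset of length 3 and therefore yields cont_cols, cat_cols and label one at a time: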

    >>> dataset = MyDataset()
    >>> dl = DataLoader(dataset, batch_size=1)
    

    Then access your dataloader content in a loop:

    >>> for cont, cat, label in dl:
    ...   print(cont, cat, label)
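
To make the difference concrete, here is a minimal sketch that contrasts the two calls. ToyDataset, its random data, and the column counts (16 continuous, 17 categorical) are assumptions made up for illustration. They stand in for the real MyDataset, whose __init__ and __len__ are not shown in the question.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical stand-in for MyDataset, filled with random data."""

    def __init__(self, n_rows=4, n_cont=16, n_cat=17):
        self.cont_cols = torch.rand(n_rows, n_cont)
        self.cat_cols = torch.randint(0, 5, (n_rows, n_cat)).float()
        self.label = torch.randint(0, 2, (n_rows,), dtype=torch.int32)

    def __len__(self):
        # One sample per row.
        return len(self.label)

    def __getitem__(self, idx):
        # Each sample is a 3-tuple: (continuous, categorical, label).
        return self.cont_cols[idx], self.cat_cols[idx], self.label[idx]


dataset = ToyDataset()

# Passing the whole dataset: one batch per row, and every batch holds
# three tensors that can be indexed or unpacked together.
good_batches = list(DataLoader(dataset, batch_size=1))
first = good_batches[0]
print(len(good_batches), [tuple(t.shape) for t in first])

# Passing dataset[0]: the 3-tuple itself is treated as a dataset of
# length 3, so the loader yields its three elements one after another.
bad_batches = list(DataLoader(dataset[0], batch_size=1))
print(len(bad_batches), [tuple(b.shape) for b in bad_batches])
```

With the whole dataset, each batch can be read as data[0], data[1] and data[2], or unpacked directly in the for statement as shown in the loop above.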