python nlp pytorch

Create iterator from a Data Frame in Python


I am working on an NLP project using Seq2Seq. I built a data frame from my dataset and then created a batch iterator with DataLoader, as in the following code:

import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader

# create a list of [source, target] pairs, one per tab-separated line
original_word_pairs = [l.split('\t') for l in lines[:num_examples]]
data = pd.DataFrame(original_word_pairs, columns=["src", "trg"])

# convert the data to tensors and pass them to the DataLoader
# to create a batch iterator

class MyData(Dataset):
    def __init__(self, X, y):
        self.data = X
        self.target = y
        # TODO: convert this into torch code if possible
        # number of non-padding (non-zero) entries in each sequence
        self.length = [np.sum(1 - np.equal(x, 0)) for x in X]
        
    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        x_len = self.length[index]
        return x,y,x_len
    
    def __len__(self):
        return len(self.data)

train_dataset = MyData(input_tensor_train, target_tensor_train)
val_dataset = MyData(input_tensor_val, target_tensor_val)

train_dataset = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                           drop_last=True,
                           shuffle=True)
test_dataset = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                          drop_last=True,
                          shuffle=True)

That is part of my code. I want to use the iterators like this:

for i, batch in enumerate(iterator):
    src = batch.src
    trg = batch.trg

But I get the error "AttributeError: 'list' object has no attribute 'src'". How can I use the iterator and access a specific column by name?


Solution

  • You can redefine __getitem__ in your Dataset to return a dictionary:

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        x_len = self.length[index]
        return {"src": x, "trg": y, "x_len": x_len}
    

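    The default collate_fn of DataLoader then returns a dictionary of batched values instead of single observations, so you access the fields as batch["src"] and batch["trg"] rather than batch.src and batch.trg. For this to work, every value must be a type the default collate_fn can batch; if x_len is not, convert it to a tensor inside __getitem__ (or pass a custom collate_fn).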
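
Both options mentioned above can be sketched as follows. This is a minimal illustration, not the asker's actual pipeline: the class name PairDataset, the helper collate_to_namespace, and the toy X and y lists are made up for the example, and it assumes the inputs are already padded integer sequences where 0 marks padding. The second loader uses a custom collate_fn that wraps each batch in a SimpleNamespace, which is one possible way to keep the batch.src syntax from the question.

```python
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical dataset of padded source/target index sequences (0 = padding)."""

    def __init__(self, X, y):
        self.data = torch.as_tensor(X)
        self.target = torch.as_tensor(y)
        # Count the non-padding tokens per source sequence, kept as a tensor
        # so that every returned field can be batched the same way.
        self.length = (self.data != 0).sum(dim=1)

    def __getitem__(self, index):
        return {
            "src": self.data[index],
            "trg": self.target[index],
            "x_len": self.length[index],
        }

    def __len__(self):
        return len(self.data)


def collate_to_namespace(samples):
    """Hypothetical custom collate_fn: stack each field and expose it as an attribute."""
    return SimpleNamespace(
        src=torch.stack([s["src"] for s in samples]),
        trg=torch.stack([s["trg"] for s in samples]),
        x_len=torch.stack([s["x_len"] for s in samples]),
    )


# Toy data (assumption): 4 sequences, each padded to length 5.
X = [[1, 2, 3, 0, 0], [4, 5, 0, 0, 0], [6, 7, 8, 9, 0], [1, 0, 0, 0, 0]]
y = [[2, 3, 0, 0, 0], [5, 6, 7, 0, 0], [8, 0, 0, 0, 0], [9, 1, 2, 3, 0]]
dataset = PairDataset(X, y)

# Option 1: default collate_fn, each batch is a dict of stacked tensors.
dict_loader = DataLoader(dataset, batch_size=2, drop_last=True)
for i, batch in enumerate(dict_loader):
    src, trg, x_len = batch["src"], batch["trg"], batch["x_len"]

# Option 2: custom collate_fn, each batch supports attribute access.
attr_loader = DataLoader(dataset, batch_size=2, drop_last=True,
                         collate_fn=collate_to_namespace)
for i, batch in enumerate(attr_loader):
    src, trg = batch.src, batch.trg
```

With the default collate_fn the loop body has to use key access; the namespace variant trades that for a small amount of extra code and lets the original loop run unchanged.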