Tags: python, text, nlp, padding, data-preprocessing

Varying embedding dim because the padding length changes from batch to batch


I want to train a simple neural network that takes embedding_dim as a parameter:

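import torch
import torch.nn as nn

class BoolQNN(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        # embedding_dim must equal the width of the concatenated question + passage input
        self.fc1 = nn.Linear(embedding_dim, 64)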
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 1)

    def forward(self, question_emb, passage_emb):
        combined = torch.cat((question_emb, passage_emb), dim=1)
        x = self.fc1(combined)
        x = self.relu(x)
        x = self.fc2(x)
        return torch.sigmoid(x)

To load the data, I use PyTorch's DataLoader with a custom collate_fn:

train_dataset = BoolQDataset(train_data, pretrained_embeddings)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn_padd)

model = BoolQNN(301)

The collate_fn_padd function looks like this:

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn_padd(batch):
    questions, passages, labels = zip(*batch)

    questions = [torch.tensor(q) for q in questions]
    passages = [torch.tensor(p) for p in passages]

    # pad_sequence pads only up to the longest sequence in *this* batch
    padded_questions = pad_sequence(questions, batch_first=True, padding_value=0)
    padded_passages = pad_sequence(passages, batch_first=True, padding_value=0)

    labels = torch.tensor(labels, dtype=torch.float32)

    return padded_questions, padded_passages, labels

The problem: in every batch, the embedded text is padded only to the length of the longest sequence in that batch, so the padded length differs from batch to batch.

That means the input size of the first linear layer in my network changes from batch to batch, although I want it to be the same for every batch.
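A minimal sketch of this behaviour, assuming each question is a 1-D sequence of numbers (the toy tensors below are made up for illustration): `pad_sequence` only pads up to the longest sequence in the list it receives, so two batches come out with different widths.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two hypothetical batches of 1-D sequences with different maximum lengths.
batch_a = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
batch_b = [torch.tensor([1.0]), torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0])]

padded_a = pad_sequence(batch_a, batch_first=True, padding_value=0)
padded_b = pad_sequence(batch_b, batch_first=True, padding_value=0)

# Each batch is padded only to its own longest sequence.
print(padded_a.shape)  # torch.Size([2, 3])
print(padded_b.shape)  # torch.Size([2, 5])
```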

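As a result, I get errors such as: mat1 and mat2 shapes cannot be multiplied (16x182 and 301x64). Here 16 is the batch size and 182 is the combined padded length of the questions and passages in that batch, while fc1 expects 301 input features.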
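The mismatch can be reproduced in isolation with a sketch like the following. Only the total width of 182 comes from the error message; the split into 20 question positions and 162 passage positions is a hypothetical example.

```python
import torch
import torch.nn as nn

fc1 = nn.Linear(301, 64)  # same layer shape as BoolQNN(301).fc1

# Hypothetical padded lengths for one batch of 16: 20 (questions) + 162 (passages) = 182
question_emb = torch.zeros(16, 20)
passage_emb = torch.zeros(16, 162)
combined = torch.cat((question_emb, passage_emb), dim=1)
print(combined.shape)  # torch.Size([16, 182])

error_message = None
try:
    fc1(combined)  # fc1 expects 301 input features but receives 182
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

The exact wording of the error depends on the PyTorch version, but it is raised whenever the concatenated width differs from the `in_features` of `fc1`.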

Is it possible to adjust the collate_fn_padd function so that it pads every sequence to the same length, independent of the batch?


Solution

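  • Add a max_length argument to the collate function and pad and truncate every sequence to exactly that length, instead of to the longest sequence in the batch. DataLoader calls collate_fn with the batch only, so bind the extra argument with functools.partial, e.g. collate_fn=functools.partial(collate_fn_padd, max_length=max_length). Because the model concatenates the question and passage tensors, fc1 then receives 2 * max_length features, so create the model with BoolQNN(2 * max_length) rather than setting max_length to embedding_dim. Truncate before padding, and stack the results so each batch is a single tensor:

    padded_questions = torch.stack([torch.nn.functional.pad(torch.tensor(q)[:max_length], (0, max_length - min(len(q), max_length)), value=0) for q in questions])
    padded_passages = torch.stack([torch.nn.functional.pad(torch.tensor(p)[:max_length], (0, max_length - min(len(p), max_length)), value=0) for p in passages])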
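A self-contained sketch of a fixed-length collate function, assuming each question and passage is a 1-D sequence of numbers. The names `pad_or_truncate`, `collate_fn_fixed`, `MAX_QUESTION_LEN` and `MAX_PASSAGE_LEN` are hypothetical, and the lengths 21 and 280 are illustrative values chosen only because they sum to 301, the fc1 input size used in the question, so `BoolQNN(301)` could stay unchanged.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Hypothetical fixed lengths; their sum must equal the fc1 input size (301 here),
# because the model concatenates the question and passage tensors.
MAX_QUESTION_LEN = 21
MAX_PASSAGE_LEN = 280


def pad_or_truncate(seq, max_length):
    """Truncate a 1-D sequence to max_length, then right-pad it with zeros."""
    t = torch.as_tensor(seq, dtype=torch.float32)[:max_length]
    return F.pad(t, (0, max_length - t.size(0)), value=0)


def collate_fn_fixed(batch):
    questions, passages, labels = zip(*batch)
    # Stack into (batch_size, fixed_length) tensors, independent of the batch contents.
    padded_questions = torch.stack([pad_or_truncate(q, MAX_QUESTION_LEN) for q in questions])
    padded_passages = torch.stack([pad_or_truncate(p, MAX_PASSAGE_LEN) for p in passages])
    labels = torch.tensor(labels, dtype=torch.float32)
    return padded_questions, padded_passages, labels


# Toy data (made up): one over-long passage and one over-long question.
toy_data = [
    ([1.0, 2.0], [3.0] * 500, 1),
    ([4.0] * 30, [5.0, 6.0], 0),
]
loader = DataLoader(toy_data, batch_size=2, collate_fn=collate_fn_fixed)
q, p, y = next(iter(loader))
print(q.shape, p.shape)                # torch.Size([2, 21]) torch.Size([2, 280])
print(torch.cat((q, p), dim=1).shape)  # torch.Size([2, 301])
```

Truncating before padding keeps the pad amount non-negative, and `torch.stack` returns one tensor per batch, just as `pad_sequence` did. Using separate lengths lets the question and passage budgets differ while their sum stays fixed.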