Tags: python, neural-network, pytorch, bert-language-model, siamese-network

PyTorch Siamese NN with BERT for sentence matching


I'm trying to build a Siamese neural network in PyTorch that takes BERT word embeddings as input and predicts whether two sentences are similar (think duplicate-post matching, product matching, etc.). Here's the model:

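import torch

class SiameseNetwork(torch.nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Shared ("twin") encoder: both sentences pass through the same weights.
        # It expects a flattened input of shape (batch, 512 * 768).
        self.brothers = torch.nn.Sequential(
            torch.nn.Linear(512 * 768, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(512, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, 32),
        )

        # Classification head applied to the element-wise squared difference
        # of the two 32-d encodings; outputs two logits, one per class.
        self.final = torch.nn.Sequential(
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(16, 2),
        )

    def forward(self, left, right):
        outputLeft = self.brothers(left)
        outputRight = self.brothers(right)
        output = self.final((outputLeft - outputRight) ** 2)
        return output

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bros = SiameseNetwork()
bros = bros.to(device)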

Criterion and optimizer:

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=bros.parameters(), lr=0.001)

The training loop:

from tqdm import tqdm

# Runs inside the epoch loop, where tLoader, epoch and trainingLoss are defined.
for batch in tqdm(tLoader, desc=f"Train epoch: {epoch+1}"):
    a = batch[0].to(device)
    b = batch[1].to(device)
    # CrossEntropyLoss expects class indices of shape (N,) and dtype long
    y = batch[2].long().to(device)

    optimizer.zero_grad()

    output = bros(a, b)
    loss = criterion(output, y)
    loss.backward()

    trainingLoss += loss.item()

    optimizer.step()
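For a binary match/no-match label, PyTorch offers two standard loss setups, and the target tensor must match whichever one is used: keep two logits and use `torch.nn.CrossEntropyLoss` with class indices of shape (N,) and dtype long, or emit a single logit and use `torch.nn.BCEWithLogitsLoss` with float targets of shape (N, 1), which is the shape that `torch.unsqueeze(..., 1)` produces. The sketch below shows both, using random stand-in logits instead of the real model; the single-logit variant is a swapped-in alternative and would need `Linear(16, 1)` as the last layer.

```python
import torch

torch.manual_seed(0)
# Stand-in for batch[2]: hypothetical 0/1 labels (1 = match, 0 = no match)
labels = torch.tensor([0, 1, 1, 0])

# Option 1: two logits per pair, as with a final Linear(16, 2)
logits2 = torch.randn(4, 2)
ce = torch.nn.CrossEntropyLoss()
y_ce = labels.long()                       # shape (N,), dtype long
loss_ce = ce(logits2, y_ce)

# Option 2 (alternative): one logit per pair, i.e. a final Linear(16, 1)
logits1 = torch.randn(4, 1)
bce = torch.nn.BCEWithLogitsLoss()
y_bce = labels.float().unsqueeze(1)        # shape (N, 1), dtype float
loss_bce = bce(logits1, y_bce)

# Turning each setup's output into a 0/1 prediction per pair
pred_ce = logits2.argmax(dim=1)
pred_bce = (torch.sigmoid(logits1) > 0.5).long().squeeze(1)

print(loss_ce.item(), loss_bce.item())
```

Both are equivalent in expressive power for two classes; the main thing is not to mix one loss with the other's target format.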

This seems to work, as it produces reasonable results, but the validation error stops dropping at 0.13 after just a couple of epochs. I can't find much about this kind of network in PyTorch. Are there ways to optimize it? Am I doing something wrong?


Solution

  • Your first layer is severely overparameterized and prone to overfitting: it has about 201 million parameters (512 * 768 inputs times 512 outputs, plus biases). I assume the shape 512 * 768 reflects the number of tokens times their dimensionality; if so, you need to rethink your architecture. You need some form of weight sharing or pooling to reduce the num_words * dim input to a fixed-size representation (that's exactly why recurrent networks replaced fully-connected ones for sentence encoding). In transformer-based architectures specifically, the [CLS] token (token 0, prefixed to the input) is typically used as the "summary" token for sequence-level and sequence-pair-level tasks.
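A minimal sketch of the pooling idea, assuming you already have BERT's last-layer hidden states of shape (batch, 512, 768) plus an attention mask (in Hugging Face transformers these are typically `outputs.last_hidden_state` and the tokenizer's `attention_mask`). Here they are simulated with `torch.randn`. `cls_pool` takes the [CLS] vector at position 0; `mean_pool` is one alternative pooling strategy that averages over non-padding tokens. `cls_pool`, `mean_pool` and `PooledSiamese` are hypothetical names; the encoder keeps the question's layer sizes and only changes the input from 512 * 768 to 768.

```python
import torch


def cls_pool(hidden):
    # Assumption: hidden is (batch, seq_len, dim) with [CLS] at position 0
    return hidden[:, 0]


def mean_pool(hidden, mask):
    # Average over real tokens only; mask is 1 for tokens, 0 for padding
    mask = mask.unsqueeze(-1).to(hidden.dtype)      # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)             # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1.0)         # avoid division by zero
    return summed / counts


class PooledSiamese(torch.nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        # Same shared encoder as the question, but fed one pooled vector per sentence
        self.brothers = torch.nn.Sequential(
            torch.nn.Linear(dim, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(512, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(256, 32),
        )
        self.final = torch.nn.Sequential(
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(16, 2),
        )

    def forward(self, left, right):
        # left, right: pooled sentence vectors of shape (batch, dim)
        outLeft = self.brothers(left)
        outRight = self.brothers(right)
        return self.final((outLeft - outRight) ** 2)


torch.manual_seed(0)
# Simulated BERT outputs: 4 sentence pairs, 512 tokens, 768 dims
hidden_a = torch.randn(4, 512, 768)
hidden_b = torch.randn(4, 512, 768)
mask = torch.ones(4, 512)
mask[:, 100:] = 0  # pretend everything after token 100 is padding

model = PooledSiamese()
logits = model(cls_pool(hidden_a), cls_pool(hidden_b))
print(logits.shape)  # torch.Size([4, 2])

# Parameter count of the first Linear layer, before and after pooling
first_old = 512 * 768 * 512 + 512
first_new = sum(p.numel() for p in model.brothers[0].parameters())
print(first_old, first_new)  # 201327104 393728
```

Swapping `cls_pool` for `mean_pool(hidden, mask)` is a one-line change. In practice you could also pool once per sentence ahead of time, so the data loader yields (batch, 768) tensors instead of (batch, 512 * 768).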