Tags: deep-learning, pytorch, lstm, transformer-model, language-model

Why is my Transformer implementation losing to a BiLSTM?


I am working on a sequence tagging problem and use a single Transformer encoder to obtain logits for each element of the sequence. Having experimented with both a Transformer and a BiLSTM, it looks like the BiLSTM works better in my case, so I was wondering whether my Transformer implementation has some problem. Below is my implementation of the Transformer encoder and the related functions for creating the padding mask and the positional embeddings:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def create_mask(src, lengths):
    """Create a key padding mask marking the padded steps of each sequence.
    Parameters:
        src (tensor): the source tensor of shape [batch_size, number_of_steps, features_dimensions]
        lengths (list): a list of integers giving the true length (i.e. number_of_steps) of each sample in the batch."""
    mask = []
    max_len = src.shape[1]
    for index, _ in enumerate(src):
        # True marks padded steps that must be ignored, False marks real steps
        # (the convention expected by src_key_padding_mask)
        mask.append([True if (step + 1) > lengths[index] else False for step in range(max_len)])
    return torch.tensor(mask)
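
For example, this is the kind of mask the function produces (dummy values only; True marks the padded steps the encoder should ignore):

# Illustrative check with dummy values
dummy_batch = torch.zeros(3, 4, 8)  # [batch_size, seq_len, features]
print(create_mask(dummy_batch, [4, 2, 3]))
# tensor([[False, False, False, False],
#         [False, False,  True,  True],
#         [False, False, False,  True]])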

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000, device = 'cpu'):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.device = device
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :].to(self.device)
        return self.dropout(x)

class Transformer(nn.Module):
    """Class implementing transformer ecnoder, partially based on
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html"""
    def __init__(self, in_dim, h_dim, n_heads, n_layers, dropout=0.2, drop_out = 0.0, batch_first = True, device = 'cpu', positional_encoding = True):
        super(Transformer, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(in_dim, dropout, device = device)
        encoder_layers = nn.TransformerEncoderLayer(in_dim, n_heads, h_dim, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_layers, norm=nn.LayerNorm(in_dim))
        self.in_dim = in_dim
        self.drop_out = drop_out
        self.positional_encoding = positional_encoding
    
        
    def forward(self, src, mask=None, line_len=None):
        src = src * math.sqrt(self.in_dim)
        if self.positional_encoding:
            src = self.pos_encoder(src)
        # Build the padding mask from the lengths only when no mask was passed explicitly
        if line_len is not None and mask is None:
            mask = create_mask(src, line_len).to(src.device)
        output = self.transformer_encoder(src, src_key_padding_mask=mask)
        if self.drop_out:
            output = F.dropout(output, p=self.drop_out)
        return src, output

As can be seen, the network above outputs the hidden states, which I then pass through an additional linear layer and train with a cross-entropy loss over two classes and the Adam optimizer (a rough sketch of this head and the training step is shown below). I have tried multiple combinations of hyperparameters, but the BiLSTM still performs better. Can anyone spot anything off in my Transformer, or suggest why I am seeing such a counterintuitive result?
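
Sketch of the classification head and training step (the dimensions, number of classes and learning rate are placeholder values, not my actual hyperparameters):

model = Transformer(in_dim=128, h_dim=256, n_heads=8, n_layers=2)
classifier = nn.Linear(128, 2)  # two classes per sequence element
optimizer = torch.optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()  # padded label positions can be excluded via ignore_index

def training_step(batch, lengths, labels):
    # batch: [batch_size, seq_len, in_dim], labels: [batch_size, seq_len] (long tensor)
    _, hidden = model(batch, line_len=lengths)
    logits = classifier(hidden)  # [batch_size, seq_len, 2]
    loss = criterion(logits.reshape(-1, 2), labels.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()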


Solution

  • This may be surprising, but Transformers don't always beat LSTMs. For example, the paper Language Models with Transformers states:

    Transformer architectures are suboptimal for language model itself. Neither self-attention nor the positional encoding in the Transformer is able to efficiently incorporate the word-level sequential context crucial to language modeling.

    If you run the Transformer tutorial code itself (on which your code is based), you'll also see the LSTM do better there. See this thread on stats.SE for more discussion of this topic (disclaimer: both the question and the answer there are mine).
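
If you want to make the comparison as apples-to-apples as possible, a BiLSTM encoder exposing the same (src, output) interface as your Transformer class might look roughly like the sketch below (placeholder hyperparameters, not tuned on your data):

class BiLSTM(nn.Module):
    """Minimal bidirectional LSTM encoder with the same (src, output) interface
    as the Transformer class above (a sketch, hyperparameters are placeholders)."""
    def __init__(self, in_dim, h_dim, n_layers=1, dropout=0.0):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, h_dim, num_layers=n_layers,
                            batch_first=True, bidirectional=True,
                            dropout=dropout if n_layers > 1 else 0.0)

    def forward(self, src, mask=None, line_len=None):
        if line_len is not None:
            # Pack the sequences so the LSTM skips the padded steps
            packed = nn.utils.rnn.pack_padded_sequence(src, line_len, batch_first=True, enforce_sorted=False)
            packed_out, _ = self.lstm(packed)
            output, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=src.size(1))
        else:
            output, _ = self.lstm(src)
        return src, output  # output: [batch_size, seq_len, 2 * h_dim]

The linear head on top then needs an input size of 2 * h_dim instead of in_dim, since the forward and backward hidden states are concatenated.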