Tags: deep-learning, pytorch, simpletransformers

Custom Multihead Attention class leaks data for causal attention


I have been implementing a custom multi-head attention class in PyTorch for a Transformer model, for learning purposes. My implementation has no extra functionality; I just want to get a base case working. I've noticed that for causal attention (tokens can't attend to future tokens) my model seems to suffer from data leakage. I came to that conclusion after testing the same script with torch's nn.MultiheadAttention class.

To me, it seems that the problem is in the way I apply the mask, but I can't really find it. I've tested that a two-dimensional mask broadcasts properly to four-dimensional tensors (which is my approach), and I've verified several times that the right tokens are masked, to no avail.
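
For reference, this is the kind of broadcasting check I mean (a minimal sketch with made-up sizes):

```python
import torch

B, h, T = 2, 3, 4
scores = torch.randn(B, h, T, T)                         # attention scores
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()   # True strictly above the diagonal

# A (T, T) boolean mask broadcasts over the batch and head dims of a (B, h, T, T) tensor
masked = scores.masked_fill(mask, float('-inf'))
print(masked[0, 0])  # each row i has -inf in columns j > i, so future positions are masked
```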

This is the code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# block_size, d_model and n_heads are hyperparameters defined elsewhere in the script
class MultiHeadAttention(nn.Module):

    def __init__(self, n_heads, d_model,  dropout=0.1):

        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.query = nn.Linear(d_model, d_model, bias=False)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.att_proj = nn.Linear(d_model, d_model, bias=False)
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())

    def forward(self, x):

        q = x
        k = x
        v = x
        B,T,C = x.shape 
        dk = d_model // n_heads

        # linear projections
        q = self.query(q) 
        k = self.key(k) 
        v = self.value(v) 

        # add number of heads
        q = q.view(B,T,n_heads,dk).permute(0,2,1,3)   # B,h,T,dk
        k = k.view(B,T,n_heads,dk).permute(0,2,1,3)  
        v = v.view(B,T,n_heads,dk).permute(0,2,1,3)  
        
        # attention 

        x = q @ k.transpose(-2,-1) # B,h,T,dk @ B,h,dk,T --> B,h,T,T
        x = x * dk ** -0.5 # B,h,T,T
        x = x.masked_fill(self.mask, float('-inf')) # B,h,T,T
        x = F.softmax(x, dim=(-1)) # B,n_h,T,T 
        x = x @ v  # B,h,T,T @ B,h,T,dv --> B,h,T,dv
        x = x.view(B,T,-1)
        out = self.att_proj(x) # B,T,C

        return out
```

With a toy example I quickly get to losses such as Training Loss: 2.307, Evaluation Loss: 2.278. When using torch's nn.MultiheadAttention the losses are far less ambitious: Iteration 9999. Training Loss: 2.469. Evaluation Loss: 2.483. What am I missing?
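
For the reference run, I replaced my module with torch's implementation roughly along these lines (a sketch, not my exact training script; the sizes are placeholders):

```python
import torch
import torch.nn as nn

d_model, n_heads, B, T = 32, 4, 8, 16   # placeholder hyperparameters
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)

# Boolean attn_mask: True marks positions that must NOT be attended to
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

x = torch.randn(B, T, d_model)
out, _ = mha(x, x, x, attn_mask=causal_mask, need_weights=False)   # out: (B, T, d_model)
```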

This is my model implementation, just in case the error is there:

```python
class Model(nn.Module):

    def __init__(self, vocab_size, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.embedding_table = nn.Embedding(vocab_size, d_model)
        self.mha = MultiHeadAttention(n_heads, d_model)
        self.out = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x, targets=None):

        x = self.embedding_table(x)
        B, T, C = x.shape

        x = self.mha(x) # B,T,C
        logits = self.out(x) # B,T,vocab_size

        if targets is not None:
            logits = logits.reshape(-1, logits.shape[-1])
            targets = targets.reshape(-1)
            loss = F.cross_entropy(input=logits, target=targets)
        else:
            loss = None

        return logits, loss

    def generate(self, n_chars, ix):

        for _ in range(n_chars):

            logits, loss = self(ix) # B, T, vocab_size
            logits = logits[:,-1,:] # B, vocab_size -- keep only the last position to sample the next token
            probs = F.softmax(logits, dim=-1) # B, vocab_size
            next_ix = torch.multinomial(input=probs, num_samples=1)
            ix = torch.cat((ix, next_ix), dim=1)

        return ix
```

I've tried different train/validation split methods to make sure the leakage wasn't happening there. Then I tried several masking approaches: using tril and filling the 0s with -inf, or using triu and filling the Trues with -inf. I have made sure diagonal is 1 so that only future tokens are masked.
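
Concretely, the two masking variants I tried should be interchangeable (a small sketch):

```python
import torch

T = 5
scores = torch.randn(T, T)

# Variant 1: tril of ones, fill the zeros (the upper triangle) with -inf
tril = torch.tril(torch.ones(T, T))
a = scores.masked_fill(tril == 0, float('-inf'))

# Variant 2: triu with diagonal=1 as a boolean mask, fill the Trues with -inf
triu = torch.triu(torch.ones(T, T), diagonal=1).bool()
b = scores.masked_fill(triu, float('-inf'))

assert torch.equal(a, b)  # both mask exactly the strictly-upper-triangular (future) positions
```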


Solution

  • I have tentatively found the problem: I was reshaping one of the intermediate results in the wrong way.

    After the attention-weighted values are computed (x = x @ v), I could not simply do

    x = x.view(B,T,-1)
    

    Instead, I should do

    B,h,T,dv = x.shape
    x = x.transpose(2,1).contiguous().view(B,T,h*dv) #B,T,C
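
    The difference is that view on the (B, h, T, dv) attention output only regroups memory: each resulting row mixes values from different heads and, crucially, from different time steps, so outputs computed at later positions end up in the rows used to predict earlier tokens. That is where the leak comes from. A small check, with made-up shapes:

    ```python
    import torch

    B, h, T, dv = 1, 2, 3, 4
    x = torch.arange(B * h * T * dv).view(B, h, T, dv)

    wrong = x.view(B, T, -1)                                    # regroups memory: mixes heads and time steps
    right = x.transpose(1, 2).contiguous().view(B, T, h * dv)   # time back in dim 1, then merge the heads

    print(torch.equal(wrong, right))  # False: rows of `wrong` do not correspond to single tokens
    ```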