Search code examples
Tags: python, pytorch, large-language-model

LLM output repeating itself


I am currently following this tutorial on making a basic LLM that spews Shakespeare-like text (the full code for the transformer is at the end). I have reached the end, but when I train it and generate output, the output just keeps repeating the same text. Here is my code:

import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
from functions.encode import encode_chars
from functions.character_amount import character_amount
from functions.train_test_split import train_test_split
from functions.decoding import decoding
# Load the raw training text. An explicit encoding makes decoding independent
# of the platform default (e.g. cp1252 on Windows). The handle gets its own
# name so it is not rebound to the text it produced.
with open(r'example_shakespeare_text.txt', encoding='utf-8') as handle:
    file = handle.read()
# NOTE(review): splitting on '\n' drops the newline characters themselves;
# whether encode_chars re-inserts them is not visible here — confirm.
split = file.split('\n')

# Training / model hyperparameters.
# NOTE(review): 25 optimizer steps is very little training for a model of this
# size; expect near-random samples until this is raised substantially.
max_iters = 25
num_embed = 64
num_heads = 16
num_layers = 8
batch_size = 32
block_size = 128
dropout = 0.2
learning_rate = 1e-3

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# GPT-2 byte-pair encoding from tiktoken; the model's vocabulary is the
# tokenizer's full vocabulary.
encode = tiktoken.get_encoding('gpt2')

# NOTE(review): `characters` is not used anywhere in the visible code.
characters = character_amount(encode=encode, split=split)
vocab_size = encode.n_vocab

encoded = encode_chars(split=split, encode=encode)

data = torch.tensor(encoded, dtype=torch.long)
train_data, test_data = train_test_split(data=data)

def array_creation(split):
    """Sample a random batch of input/target windows from one data split.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``test_data``.

    Returns:
        ``(x, y)``, each of shape (batch_size, block_size) and moved to
        ``device``. ``y`` is ``x`` shifted one token to the right, so
        ``y[:, t]`` is the token that follows ``x[:, t]``.
    """
    source = train_data if split == 'train' else test_data

    # Random window starts, leaving room for the one-token target shift.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)
        
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    The input is projected into key, query and value vectors of width
    ``head_size``. Each position may attend only to itself and earlier
    positions.

    Args:
        head_size: Width of the key/query/value projections and of the output.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(num_embed, head_size, bias=False)
        self.query = nn.Linear(num_embed, head_size, bias=False)
        self.value = nn.Linear(num_embed, head_size, bias=False)
        # Lower-triangular causal mask. It is registered as a buffer so it moves
        # with .to(device) but is not a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the attention output for ``x``, shape (B, T, head_size).

        ``x`` is expected to be (B, T, num_embed) with T <= block_size, since
        the causal mask only covers block_size positions.
        """
        _, T, _ = x.shape
        # Use the registered (trained) projections. Creating nn.Linear layers
        # inside forward would give new random weights on every call. Those
        # weights would never reach the optimizer and would live on the CPU.
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product attention: scale by 1/sqrt(d_k), where d_k is the
        # key width, not the embedding width.
        weight = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Block attention to future positions.
        weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weight = F.softmax(weight, dim=-1)
        weight = self.dropout(weight)

        v = self.value(x)
        return weight @ v

class MultiHead(nn.Module):
    """Several causal attention heads run in parallel, then projected back.

    Args:
        num_heads: Number of ``Head`` instances.
        head_size: Output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. That
        # equals num_embed only when num_embed is divisible by num_heads, so
        # size the projection from the actual concatenated width.
        self.prj = nn.Linear(num_heads * head_size, num_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last axis and project to num_embed."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.prj(out))

class FeedForward(nn.Module):
    """Per-position MLP: expand to 4x width, ReLU, project back, then dropout.

    Args:
        num_embed: Input and output feature width.
    """

    def __init__(self, num_embed):
        super().__init__()
        hidden_width = 4 * num_embed
        layers = [
            nn.Linear(num_embed, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last (feature) axis of ``x``."""
        return self.net(x)
    
class Block(nn.Module):
    """Transformer block: multi-head self-attention then a feed-forward MLP.

    Each sub-layer is applied to a LayerNorm-ed copy of its input, and its
    result is added back onto that input (residual connection).

    Args:
        num_embed: Feature width of the input and output.
        num_heads: Number of attention heads; each gets num_embed // num_heads
            features.
    """

    def __init__(self, num_embed, num_heads):
        super().__init__()
        per_head = num_embed // num_heads
        self.sa = MultiHead(num_heads, per_head)
        self.ffwd = FeedForward(num_embed)
        self.layernorm1 = nn.LayerNorm(num_embed)
        self.layernorm2 = nn.LayerNorm(num_embed)

    def forward(self, x):
        attended = x + self.sa(self.layernorm1(x))
        return attended + self.ffwd(self.layernorm2(attended))

class BigramLanguageModel(nn.Module):
    """Decoder-only transformer language model over ``vocab_size`` tokens.

    Despite the name, it is not a bigram model. It sums token and positional
    embeddings, runs them through ``num_layers`` transformer ``Block``s and a
    final LayerNorm, and maps the result to next-token logits.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, num_embed)
        self.position_embedding_table = nn.Embedding(block_size, num_embed)
        self.blocks = nn.Sequential(*[Block(num_embed, num_heads=num_heads) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(num_embed)
        self.lm_head = nn.Linear(num_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer token ids of shape (B, T), with T <= block_size.
            targets: Optional integer target ids of shape (B, T).

        Returns:
            ``(logits, loss)``.
            - Without ``targets``: logits is (B, T, vocab_size) and loss is None.
            - With ``targets``: logits is flattened to (B*T, vocab_size) and
              loss is a scalar tensor.
        """
        _, T = idx.shape
        token_emb = self.token_embedding_table(idx)
        # Build position indices on the input's device rather than relying on
        # the module-level `device` global.
        position_embedding = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = token_emb + position_embedding
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively after ``idx``.

        Args:
            idx: Prompt token ids of shape (B, T), on the same device as the model.
            max_new_tokens: Number of tokens to append.

        Returns:
            Token ids of shape (B, T + max_new_tokens).

        The model is put in eval mode while sampling, so dropout is disabled.
        Its previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # The position table only covers block_size positions.
                idx_cond = idx[:, -block_size:]
                logits, _ = self(idx_cond)
                # Only the last position predicts the next token.
                logits = logits[:, -1, :]

                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

m = BigramLanguageModel()
model = m.to(device)

# Baseline sample from the untrained model. The prompt must be on the model's
# device, or the embedding lookup fails when running on CUDA. Eval mode
# disables dropout while sampling.
model.eval()
with torch.no_grad():
    generated_list = model.generate(
        idx=torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=100
    )[0].tolist()
decoded_list = decoding(generated_list=generated_list, encode=encode)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Re-enable dropout for training.
model.train()
iteration = 0
for _ in range(max_iters):
    xb, yb = array_creation('train')

    _, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    iteration += 1

    print(iteration)
    print(loss.item())

# Sample from the trained model with dropout disabled.
model.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
    print(decoding(generated_list=model.generate(context, max_new_tokens=100)[0].tolist(), encode=encode))

Here is the output:

A', '! re al, we hear me speak.All:Speak.First Citizen:You are all resolved rather to die to than famish?A', '! re al, we hear me speak.All:Speak.First Citizen:You are all resolved rather to die to than famish?A', '! re al, we hear me speak.All:Speak.First Citizen:You are all resolved rather to die to than famish?A', '! re al, we hear me speak.All:Speak.First Citizen:You are all resolved rather to die to than famish?

It keeps repeating itself well beyond that.

I tried increasing the amount of input data, but that didn't help. I also tried changing the number of iterations and the batch size/block size, but the repetition still didn't change.

Do I just need to do even more intense training?


Solution

  • It was a problem with my decoding function. I'm not sure exactly what was wrong, but writing my own tokenizer instead of using tiktoken fixed the problem.