I’m implementing a Transformers architecture from the ground up on 1 dummy sentence.
Here is the code:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
def __init__(self, context_size, d_model):
super().__init__()
self.encoding = torch.zeros(context_size, d_model)
pos = torch.arange(0, context_size).unsqueeze(dim=1)
dim = torch.arange(
0, d_model, 2) # dim is i in the positional encoding formula
self.encoding[:, 0::2] = torch.sin(pos / (10000**(2 * dim / d_model)))
self.encoding[:, 1::2] = torch.cos(pos / (10000**(2 * dim / d_model)))
def forward(self, x):
seq_len = x.size(1)
return self.encoding[:seq_len, :]
class PositionwiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
class EncoderBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, num_heads)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
hidden_states, _ = self.self_attn(query=x, key=x, value=x)
x = self.norm1(x + hidden_states)
ff_output = self.feed_forward(x)
x = self.norm2(x + ff_output)
return x
class Encoder(nn.Module):
# input_size - # rows in token embedding
# context size - # rows in positional embedding
# d_ff - internal dimension of the FF network
# num encoder blocks
def __init__(self, input_size, context_size, d_model, d_ff, num_heads,
n_blocks):
super().__init__()
self.embedding = nn.Embedding(input_size, d_model)
self.pos_embedding = PositionalEncoding(context_size, d_model)
self.blocks = nn.ModuleList([
EncoderBlock(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
) for _ in range(n_blocks)
])
def forward(self, x):
x = self.embedding(x) + self.pos_embedding(x)
for block in self.blocks:
x = block(x)
return x
class DecoderBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, num_heads)
self.cross_attn = nn.MultiheadAttention(d_model, num_heads)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
def forward(self, x, enc_output):
lookahead_mask = torch.triu(torch.ones(x.shape[1], x.shape[1])).bool().transpose(0,1) # lookahead mask shape should be context size (1st dim)
hidden_states, _ = self.self_attn(x, x, x, attn_mask = lookahead_mask)
x = self.norm1(x + hidden_states)
hidden_states, _ = self.cross_attn(
query=x, key=enc_output, value=enc_output)
x = self.norm2(x + hidden_states)
ff_output = self.feed_forward(x)
x = self.norm3(x + ff_output)
return x
class Decoder(nn.Module):
def __init__(self, output_size, context_size,
d_model, d_ff, num_heads, n_blocks):
super().__init__()
self.embedding = nn.Embedding(output_size, d_model)
self.pos_embedding = PositionalEncoding(context_size, d_model)
self.blocks = nn.ModuleList([
DecoderBlock(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
)
for _ in range(n_blocks)
])
self.out = nn.Linear(d_model, output_size)
def forward(self, x, enc_output):
x = self.embedding(x) + self.pos_embedding(x)
for block in self.blocks:
x = block(x, enc_output)
output = self.out(x)
return output
class Transformer(nn.Module):
def __init__(self, vocab_size, context_size,
d_model, d_ff, num_heads, n_blocks):
super().__init__()
self.encoder = Encoder(
vocab_size,
context_size,
d_model,
d_ff,
num_heads,
n_blocks
)
self.decoder = Decoder(
vocab_size,
context_size,
d_model,
d_ff,
num_heads,
n_blocks
)
def forward(self, input_encoder, input_decoder):
enc_output = self.encoder(input_encoder) # (64, 100, 10)
output = self.decoder(input_decoder, enc_output) # input_decoder shape - (64, 99)
return output
SOS_token = 0
EOS_token = 1
PAD_token = 2 # Need to have padding so that the input & output sentences
# are the same length - required for the cross-attention computation
index2words = {
SOS_token: 'SOS',
EOS_token: 'EOS',
PAD_token: 'PAD'
}
words = "The animals didn't like living in the zoo"
words_list = set(words.lower().split(' '))
for word in words_list:
index2words[len(index2words)] = word
words2index = {w: i for i, w in index2words.items()}
def convert2tensors(sentence, max_len):
words_list = sentence.lower().split(' ')
padding = ['PAD'] * (max_len - len(words_list))
words_list.extend(padding)
indexes = [words2index[word] for word in words_list]
return torch.tensor(indexes, dtype=torch.long).view(1, -1)
D_MODEL = 10
VOCAB_SIZE = len(words2index)
N_BLOCKS = 10
D_FF = 20
CONTEXT_SIZE = 100
NUM_HEADS = 2
transformer = Transformer(
vocab_size=VOCAB_SIZE,
context_size=CONTEXT_SIZE,
d_model=D_MODEL,
d_ff=D_FF, # internal dimension of the feed forward network
num_heads=NUM_HEADS,
n_blocks=N_BLOCKS
)
input_sentence = "The animals"
output_sentence = "didn't like"
input_encoder = convert2tensors(input_sentence, CONTEXT_SIZE)
input_decoder = convert2tensors(output_sentence, CONTEXT_SIZE)
output_toy = transformer(input_encoder, input_decoder)
I’m adding a lookahead mask in the DecoderBlock and I get an error on this line of the forward method.
hidden_states, _ = self.self_attn(x, x, x, attn_mask = lookahead_mask)
RuntimeError: The shape of the 2D attn_mask is torch.Size([100, 100]), but should be (1, 1).
Why should the shape be (1,1)? x has shape (1,100,10) - batch, context size, d_model. lookahead_mask has shape (100,100).
Also going through this, realized not sure whether it 's ok to apply the attention mask after summing the embeddings with the positional embeddings.
You need to pass batch_first=True
to MultiheadAttention