I have trained a transformer model for character-level language modeling (i.e. predicting the next character given the context) on a dataset, with CONTEXT_LENGTH = 200. I want the model to make predictions when the input is not of length CONTEXT_LENGTH. How do I have to modify my code so it can predict on inputs of varied length? I would also appreciate help writing a function that generates the next characters.
class Embed(keras.layers.Layer):
    """word_embedding + positional_embedding"""
    def __init__(self):
        super().__init__()
        self.word_embed = keras.layers.Embedding(VOCAB_SIZE, d_model)          # (B, T) =(vocab_size, d_model)=> (B, T, d_model)
        self.position_embed = keras.layers.Embedding(CONTEXT_LENGTH, d_model)  # (B, T) =(CONTEXT_LENGTH, d_model)=> (B, T, d_model)

    def call(self, inputs):
        B, T = inputs.shape                           # when training, T = CONTEXT_LENGTH
        tok_embed = self.word_embed(inputs)           # (B, T, d_model)
        pos_embed = self.position_embed(tf.range(T))  # (T, d_model)
        return tok_embed + pos_embed                  # (B, T, d_model)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config}


class MultiHeadAttention(keras.layers.Layer):
    def __init__(self, mask: bool):
        super().__init__()
        self.mask = mask
        self.linear = keras.layers.Dense(d_model, use_bias=False)
        self.linearqkv = (keras.layers.Dense(d_k, use_bias=False),
                          keras.layers.Dense(d_k, use_bias=False),
                          keras.layers.Dense(d_v, use_bias=False))
        self.dropout = keras.layers.Dropout(0.1)

    def attention(self, Q, K, V):
        def mask_tensor(x):
            tril = tf.experimental.numpy.tril(tf.ones_like(x))
            return tf.where(tril == 0, float('-inf'), x)
        scores = Q @ tf.transpose(K, perm=[0, 2, 1]) / K.shape[-1]**0.5  # (B, T, T)
        scores = mask_tensor(scores) if self.mask else scores
        return tf.nn.softmax(scores, axis=-1) @ V                        # (B, T, d_v)

    def head(self, X):
        Q, K, V = self.linearqkv[0](X), self.linearqkv[1](X), self.linearqkv[2](X)
        return self.attention(Q, K, V)

    def call(self, X):
        heads = tf.concat([self.head(X) for _ in range(h)], axis=-1)
        output = self.linear(heads)
        output = self.dropout(output)
        return output

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "mask": self.mask}


def FeedForward():
    return keras.Sequential([
        keras.layers.Dense(d_in),
        keras.layers.ReLU(),
        keras.layers.Dense(d_model),
        keras.layers.Dropout(0.2)
    ])


inputs = keras.Input(shape=(200,))
X = Embed()(inputs)
for _ in range(N):
    Z = MultiHeadAttention(mask=True)(X)
    X = keras.layers.LayerNormalization()(Z + X)
    Z = FeedForward()(X)
    X = keras.layers.LayerNormalization()(Z + X)
outputs = keras.layers.Dense(VOCAB_SIZE, activation="softmax")(X)  # (B, T, VOCAB_SIZE)
model = keras.Model(inputs=inputs, outputs=outputs, name="transformer")
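For context, training used the standard per-position next-character classification setup, roughly along these lines (a sketch only; x_train and y_train are hypothetical names for the integer-encoded character windows of shape (num_sequences, CONTEXT_LENGTH) and their shifted-by-one targets):

# Hypothetical training sketch, not the exact training script.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss="sparse_categorical_crossentropy",  # softmax outputs vs. integer character ids
    metrics=["accuracy"],
)
model.fit(x_train, y_train, batch_size=64, epochs=10)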
I think the problem may be in the Embed layer, where tok_embed and pos_embed are added; I believe it can be modified so the model accepts inputs of varied length. Padding can hurt the model's performance, so is there another way?
Please help, thank you.
Edit: There was no problem in training; I got good accuracy. I have changed the code so that the transformer model can take inputs of varied length.
def transformer():
    class Embed(keras.layers.Layer):
        """word_embedding + positional_embedding"""
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.word_embed = keras.layers.Embedding(VOCAB_SIZE, d_model)      # (B, T) =(VOCAB_SIZE, d_model)=> (B, T, d_model)
            self.position_embed = keras.layers.Embedding(MAX_LENGTH, d_model)  # (B, T) =(MAX_LENGTH, d_model)=> (B, T, d_model)

        def call(self, inputs):
            B, T = inputs.shape  # if training, T = MAX_LENGTH
            # NOTE: inputs.shape only gives a concrete T when the model is called eagerly
            # (model(x)); model.predict traces the graph with T = None, so the slice below
            # falls back to all MAX_LENGTH rows.
            tok_embed = self.word_embed(inputs)                     # (B, T, d_model)
            pos_embed = self.position_embed(tf.range(MAX_LENGTH))   # (MAX_LENGTH, d_model)
            return tok_embed + pos_embed[:T, :]                     # (B, T, d_model) + (T, d_model) => (B, T, d_model)

        def get_config(self):
            base_config = super().get_config()
            return {**base_config}

    class MultiHeadAttention(keras.layers.Layer):
        def __init__(self, causal: bool, **kwargs):
            super().__init__(**kwargs)
            self.causal = causal
            self.linear = keras.layers.Dense(d_model, use_bias=False)
            self.linearqkv = [keras.layers.Dense(d_k, use_bias=False),
                              keras.layers.Dense(d_k, use_bias=False),
                              keras.layers.Dense(d_v, use_bias=False)]
            self.dropout = keras.layers.Dropout(0.1)

        def attention(self, Q, K, V):
            def mask_tensor(x):
                tril = tf.experimental.numpy.tril(tf.ones_like(x))
                return tf.where(tril == 0, float('-inf'), x)
            scores = Q @ tf.transpose(K, perm=[0, 2, 1]) / K.shape[-1]**0.5  # (B, T, T)
            scores = mask_tensor(scores) if self.causal else scores
            return tf.nn.softmax(scores, axis=-1) @ V                        # (B, T, d_v)

        def head(self, X):
            Q, K, V = self.linearqkv[0](X), self.linearqkv[1](X), self.linearqkv[2](X)
            return self.attention(Q, K, V)

        def call(self, X):
            heads = tf.concat([self.head(X) for _ in range(h)], axis=-1)
            output = self.linear(heads)
            return self.dropout(output)

        def get_config(self):
            base_config = super().get_config()
            return {**base_config, "causal": self.causal}

    def FeedForward():
        return keras.Sequential([
            keras.layers.Dense(d_in),
            keras.layers.ReLU(),
            keras.layers.Dense(d_model),
            keras.layers.Dropout(0.1)
        ])

    inputs = keras.Input(shape=(None,))  # sequence length left unspecified, so inputs of varied length are accepted
    x = Embed()(inputs)
    for _ in range(N):  # transformer decoder blocks
        z = MultiHeadAttention(causal=True)(x)
        x = keras.layers.LayerNormalization()(keras.layers.Add()([z, x]))
        z = FeedForward()(x)
        x = keras.layers.LayerNormalization()(keras.layers.Add()([z, x]))
    outputs = keras.layers.Dense(VOCAB_SIZE, activation="softmax")(x)  # (B, T, VOCAB_SIZE)
    model = keras.Model(inputs=inputs, outputs=outputs, name="transformer")
    print("number of parameters in the model", model.count_params())
    return model
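A quick sanity check that the rebuilt model accepts different sequence lengths when called eagerly (a sketch; the zero inputs simply stand in for valid character ids):

model = transformer()
short_seq = np.zeros((1, 10), dtype=np.int32)         # one sequence of length 10
full_seq = np.zeros((1, MAX_LENGTH), dtype=np.int32)  # one full-length sequence
print(model(short_seq, training=False).shape)         # (1, 10, VOCAB_SIZE)
print(model(full_seq, training=False).shape)          # (1, MAX_LENGTH, VOCAB_SIZE)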
For generation, call the model directly with model(..., training=False) rather than model.predict(...).
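If model.predict (or any other tf.function-traced call) should also handle variable lengths, the Embed layer can read the runtime length with tf.shape instead of relying on the static shape. A minimal sketch of just that change (the rest of the model stays the same):

class Embed(keras.layers.Layer):
    """word_embedding + positional_embedding, using the runtime sequence length"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.word_embed = keras.layers.Embedding(VOCAB_SIZE, d_model)
        self.position_embed = keras.layers.Embedding(MAX_LENGTH, d_model)

    def call(self, inputs):
        T = tf.shape(inputs)[1]                       # runtime length, valid even when the static shape is None
        tok_embed = self.word_embed(inputs)           # (B, T, d_model)
        pos_embed = self.position_embed(tf.range(T))  # (T, d_model); requires T <= MAX_LENGTH
        return tok_embed + pos_embed                  # broadcasts to (B, T, d_model)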
def generate_file(prompt: str, num_char: int, temperature=1):
    def next_char(seq):
        # Encode the sequence, run the model eagerly, and take the prediction for the last position.
        # (Dividing by temperature after the softmax and before argmax has no effect;
        # see the sampling sketch below for a version where temperature matters.)
        return sentence(tf.argmax(model(np.array(encode(seq))[np.newaxis], training=False) / temperature,
                                  axis=-1)[0].numpy().tolist())[-1]

    seq = prompt
    for i in range(num_char):
        if len(seq) >= MAX_LENGTH:
            seq += next_char(seq[-(MAX_LENGTH - 1):])  # keep the last MAX_LENGTH-1 characters so the next one can be predicted
        else:
            seq += next_char(seq)
    print(seq)
    with open("machine_generated_text.txt", "w") as f:
        f.write(seq)
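Since argmax always picks the most likely character, the temperature above has no effect. A variant that actually samples with temperature could look like this (a sketch; it reuses the encode and sentence helpers from the code above, where sentence decodes a list of character ids back to a string):

def generate_sampled(prompt: str, num_char: int, temperature: float = 1.0) -> str:
    """Generate num_char characters by sampling from the model instead of taking the argmax."""
    seq = prompt
    for _ in range(num_char):
        context = seq[-(MAX_LENGTH - 1):]  # keep at most MAX_LENGTH-1 characters of context
        probs = model(np.array(encode(context))[np.newaxis], training=False)[0, -1]  # (VOCAB_SIZE,) for the next char
        logits = tf.math.log(probs + 1e-9) / temperature  # temperature rescales the (log-)probabilities
        next_id = tf.random.categorical(logits[tf.newaxis, :], num_samples=1)[0, 0].numpy()
        seq += sentence([int(next_id)])  # decode the sampled id back to a character
    return seq

Lower temperatures push the sampling toward greedy decoding, while higher temperatures make the generated text more varied.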