Tags: python, tensorflow, keras, word2vec, word-embedding

Keras - Translation from Sequential to Functional API


I've been following a Towards Data Science tutorial on word2vec and skip-gram models, but I've run into a problem I can't solve, despite searching extensively and trying several solutions that didn't work.

https://towardsdatascience.com/understanding-feature-engineering-part-4-deep-learning-methods-for-text-data-96c44370bbfa

The step that builds the skip-gram model architecture appears to be outdated, because it uses the Merge layer from keras.layers, which was deprecated in Keras 2 and is no longer available in newer releases.

What I tried to do was translate the author's code, which uses Keras's Sequential API, into the Functional API, replacing the deprecated Merge layer with the keras.layers.Dot layer. However, I'm still stuck on the step that merges the two models (word and context) into the final model, whose architecture should look like this:

[Figure: Skip-gram model summary and architecture]
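For intuition about the merge step, here is a plain-Python sketch (no Keras) of what a dot merge over axis 1 roughly computes for 2D inputs of shape (batch, embed_size): one dot product per sample, giving shape (batch, 1). This holds both for the old Merge(mode="dot") and for keras.layers.Dot(axes=1). With normalize=True, Dot L2-normalizes each vector first, so the result is the cosine similarity. The function name batch_dot is hypothetical and exists only for this illustration.

```python
import math


def batch_dot(a, b, normalize=False):
    """Row-wise dot product of two batches of equal-length vectors.

    Mirrors the shape behaviour of a dot merge over axis 1 on 2D inputs:
    (batch, d) and (batch, d) -> (batch, 1).
    """
    out = []
    for u, v in zip(a, b):
        if normalize:
            # L2-normalise each vector so the result is a cosine similarity.
            nu = math.sqrt(sum(x * x for x in u))
            nv = math.sqrt(sum(x * x for x in v))
            u = [x / nu for x in u]
            v = [x / nv for x in v]
        # One scalar per sample, wrapped in a list to keep the (batch, 1) shape.
        out.append([sum(x * y for x, y in zip(u, v))])
    return out


words = [[1.0, 2.0, 3.0], [0.5, 0.0, -1.0]]
contexts = [[4.0, 5.0, 6.0], [2.0, 1.0, 0.0]]
print(batch_dot(words, contexts))  # [[32.0], [1.0]]
```

The question's attempt passes normalize=False, which keeps the raw (unnormalized) dot product, matching the original Merge(mode="dot") behaviour.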

Here's the code that the author used:

from keras.layers import Merge  # legacy layer; this import fails on newer Keras releases
from keras.layers.core import Dense, Reshape
from keras.layers.embeddings import Embedding
from keras.models import Sequential

# build skip-gram architecture
word_model = Sequential()
word_model.add(Embedding(vocab_size, embed_size,
                         embeddings_initializer="glorot_uniform",
                         input_length=1))
word_model.add(Reshape((embed_size, )))

context_model = Sequential()
context_model.add(Embedding(vocab_size, embed_size,
                            embeddings_initializer="glorot_uniform",
                            input_length=1))
context_model.add(Reshape((embed_size,)))

model = Sequential()
model.add(Merge([word_model, context_model], mode="dot"))
model.add(Dense(1, kernel_initializer="glorot_uniform", activation="sigmoid"))
model.compile(loss="mean_squared_error", optimizer="rmsprop")
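To make the data flow of this architecture concrete, here is a dependency-free forward pass for a single (word, context) pair. It is a sketch for intuition, not the tutorial's code: vocab_size and embed_size are placeholder values, glorot_uniform and forward are hypothetical helpers, and the Dense bias starts at zero as in Keras's default. Each step is commented with the Keras layer it stands in for.

```python
import math
import random

# Placeholder sizes (assumption); the tutorial derives these from its corpus.
vocab_size, embed_size = 10, 4
rng = random.Random(0)


def glorot_uniform(rows, cols):
    # Glorot/Xavier uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
    limit = math.sqrt(6.0 / (rows + cols))
    return [[rng.uniform(-limit, limit) for _ in range(cols)] for _ in range(rows)]


word_table = glorot_uniform(vocab_size, embed_size)     # word_model's Embedding weights
context_table = glorot_uniform(vocab_size, embed_size)  # context_model's Embedding weights
dense_w = glorot_uniform(1, 1)[0][0]                    # Dense(1) kernel on a 1-d input
dense_b = 0.0                                           # Dense bias, zero-initialised


def forward(word_id, context_id):
    # Embedding + Reshape: look up one row, giving a flat vector of length embed_size.
    w = word_table[word_id]
    c = context_table[context_id]
    # Merge(mode="dot"): a single similarity score for the pair.
    score = sum(x * y for x, y in zip(w, c))
    # Dense(1, activation="sigmoid"): affine map followed by the logistic function.
    return 1.0 / (1.0 + math.exp(-(dense_w * score + dense_b)))


def mse(pred, label):
    # mean_squared_error for one sample with a single output unit.
    return (pred - label) ** 2


p = forward(3, 7)
print(p, mse(p, 1.0))
```

Training then adjusts both embedding tables so that real skip-gram pairs (label 1) score close to 1 and random negative pairs (label 0) score close to 0.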

And here is my attempt to translate the Sequential implementation into the Functional API:

from keras import models
from keras import layers
from keras import Input, Model

word_input = Input(shape=(1,))
word_x = layers.Embedding(vocab_size, embed_size, embeddings_initializer='glorot_uniform')(word_input)
word_reshape = layers.Reshape((embed_size,))(word_x)

word_model = Model(word_input, word_reshape)    

context_input = Input(shape=(1,))
context_x = layers.Embedding(vocab_size, embed_size, embeddings_initializer='glorot_uniform')(context_input)
context_reshape = layers.Reshape((embed_size,))(context_x)

context_model = Model(context_input, context_reshape)

model_input = layers.dot([word_model, context_model], axes=1, normalize=False)
model_output = layers.Dense(1, kernel_initializer='glorot_uniform', activation='sigmoid')

model = Model(model_input, model_output)

However, running it raises the following error:

ValueError: Layer dot_5 was called with an input that isn't a symbolic tensor. Received type: . Full input: [, ]. All inputs to the layer should be tensors.

I'm a total beginner with Keras's Functional API, so I'd be grateful for some guidance on how to feed the word and context models into the Dot layer to get the architecture shown in the image.


Solution

  • You are passing Model instances to the layer. As the error says, layers in Keras must be called on Keras tensors (i.e., the outputs of layers or models), not on the models themselves. You have two options here. One is to use the .output attribute of each Model instance, like this:

    dot_output = layers.dot([word_model.output, context_model.output], axes=1, normalize=False)
    

    or equivalently, you can use the output tensors directly:

    dot_output = layers.dot([word_reshape, context_reshape], axes=1, normalize=False)
    

    Further, you need to apply the Dense layer to dot_output, and pass the Input tensors (not the dot output) as the inputs of Model. Therefore:

    model_output = layers.Dense(1, kernel_initializer='glorot_uniform',
                                activation='sigmoid')(dot_output)
    
    model = Model([word_input, context_input], model_output)
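Putting the answer's pieces together, a complete version might look like the sketch below. It assumes TensorFlow 2 with its bundled Keras (imported from tensorflow rather than the standalone keras package used in the question), uses the Dot layer class instead of the layers.dot helper (the helper is just a functional wrapper around the same layer), and picks small placeholder values for vocab_size and embed_size, which the tutorial computes from its corpus. The made-up word/context/label arrays only check that the model builds, trains for one step, and predicts values in (0, 1).

```python
import numpy as np
from tensorflow import keras

# Placeholder sizes (assumption); the tutorial derives these from its corpus.
vocab_size = 50
embed_size = 8

# Word branch: integer word id -> embedding vector of length embed_size.
word_input = keras.Input(shape=(1,), dtype="int32", name="word")
word_x = keras.layers.Embedding(vocab_size, embed_size,
                                embeddings_initializer="glorot_uniform")(word_input)
word_reshape = keras.layers.Reshape((embed_size,))(word_x)  # (batch, 1, d) -> (batch, d)

# Context branch: same structure, separate embedding weights.
context_input = keras.Input(shape=(1,), dtype="int32", name="context")
context_x = keras.layers.Embedding(vocab_size, embed_size,
                                   embeddings_initializer="glorot_uniform")(context_input)
context_reshape = keras.layers.Reshape((embed_size,))(context_x)

# Merge the branches with a dot product over the embedding axis -> (batch, 1).
dot_output = keras.layers.Dot(axes=1, normalize=False)([word_reshape, context_reshape])

# One sigmoid unit scores whether (word, context) is a real skip-gram pair.
model_output = keras.layers.Dense(1, kernel_initializer="glorot_uniform",
                                  activation="sigmoid")(dot_output)

# Model inputs are the Input tensors; the output is the final layer's tensor.
model = keras.Model([word_input, context_input], model_output)
model.compile(loss="mean_squared_error", optimizer="rmsprop")

# Smoke test on made-up (word, context, label) triples.
words = np.array([[1], [2], [3]], dtype="int32")
contexts = np.array([[4], [5], [6]], dtype="int32")
labels = np.array([[1.0], [0.0], [1.0]], dtype="float32")
loss = model.train_on_batch([words, contexts], labels)
preds = model.predict([words, contexts], verbose=0)
```

Because the model is built with [word_input, context_input] as its inputs, training and prediction data must be supplied as a two-element list in that same order, as in the train_on_batch and predict calls above.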