
Keras Embedding layer: can I alter the values of a trained embedding layer in a model's pipeline?


I am following the examples on this page: https://machinelearningmastery.com/use-word-embedding-layers-deep-learning-keras/

The tutorial trains a word embedding on the data using an Embedding layer, as below:

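from keras.models import Sequential
from keras.layers import Embedding, Flatten, Dense

max_length = 4  # assumed: length of the padded input sequences
model = Sequential()
model.add(Embedding(100, 8, input_length=max_length))
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
# compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
# summarize the model (summary() prints by itself and returns None)
model.summary()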

The model starts by learning a word embedding from the data: for each word in the 100-word vocabulary, it learns an 8-dimensional vector.
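Conceptually, the Embedding layer holds a (100, 8) weight matrix and replaces each integer word index with the corresponding row of that matrix. The NumPy sketch below mimics that lookup and the following Flatten step without Keras; the random matrix, `max_length = 4`, and the example document are made-up stand-ins for the trained weights and real padded input.

```python
import numpy as np

vocab_size, embed_dim, max_length = 100, 8, 4  # max_length = 4 is an assumed value

rng = np.random.default_rng(0)
# Stand-in for the trained matrix that model.layers[0].get_weights()[0] would return
weight_matrix = rng.normal(size=(vocab_size, embed_dim))

# One padded document of max_length integer word indices (made-up values)
doc = np.array([12, 57, 3, 0])

# The Embedding layer acts as a row lookup: word index i -> weight_matrix[i]
embedded = weight_matrix[doc]      # shape (max_length, embed_dim) = (4, 8)
flattened = embedded.reshape(-1)   # what Flatten hands to Dense: 4 * 8 = 32 values

print(embedded.shape, flattened.shape)  # (4, 8) (32,)
```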

After this embedding is learned, I would like to alter the matrix (i.e. the vector of each word) by appending two more dimensions to the end of each vector. A separate process of mine computes the values for these two dimensions.
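Because the lookup works row by row, appending two columns to the matrix gives the same result as appending the two extra values to each word's vector after the lookup. A small NumPy check of that equivalence, where `extra_dims` is a hypothetical stand-in (random values here) for the output of the separate process:

```python
import numpy as np

rng = np.random.default_rng(1)
weight_matrix = rng.normal(size=(100, 8))  # stand-in for the learned (100, 8) embedding
extra_dims = rng.normal(size=(100, 2))     # hypothetical: 2 extra values per word index

# Append the two extra columns to every row -> shape (100, 10)
new_weight_matrix = np.concatenate([weight_matrix, extra_dims], axis=1)

doc = np.array([12, 57, 3, 0])  # made-up padded document

# Looking up the augmented matrix == looking up both parts and joining them
lhs = new_weight_matrix[doc]
rhs = np.concatenate([weight_matrix[doc], extra_dims[doc]], axis=1)
print(new_weight_matrix.shape, np.array_equal(lhs, rhs))  # (100, 10) True
```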

Is there any way I can do this?

Many thanks in advance


Solution

  • Yes, it's possible. Try the following procedure:

    1. Extract the trained weight matrix:

      weight_matrix = model.layers[0].get_weights()[0] # Matrix shape (100, 8).
      
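    2. Append your two extra columns:

      import numpy as np
      # extra_dims: a (100, 2) array produced by your other process (hypothetical name)
      new_weight_matrix = np.concatenate([weight_matrix, extra_dims], axis=1)
      # Be sure that new_weight_matrix has shape (100, 10)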
      
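    3. Build an adjusted copy of your model and initialize its embedding with the new matrix:

      new_model = Sequential()
      # output_dim is now 10, and the layer starts from new_weight_matrix
      new_model.add(Embedding(100, 10, weights=[new_weight_matrix],
                              input_length=max_length))
      new_model.add(Flatten())
      new_model.add(Dense(1, activation='sigmoid'))  # input size changed: trained from scratch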
      
    4. (Optional) Freeze the embedding: if the embedding values should stay fixed during training, also pass trainable=False:

      new_model = Sequential()
      new_model.add(Embedding(100, 10, weights=[new_weight_matrix],
                              input_length=max_length,
                              trainable=False))  # frozen: embedding values are not updated
      new_model.add(Flatten())
      new_model.add(Dense(1, activation='sigmoid'))
      
    5. Compile the new model:

      new_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
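One consequence of widening the embedding: Flatten now emits 10 * max_length values instead of 8 * max_length, so the trained Dense kernel no longer fits and the Dense layer has to be retrained even when the embedding is frozen. Below is a minimal sketch of a checked helper that could play the role of `your_append`, plus that shape arithmetic. `append_dims`, `max_length = 4`, and the zero/one placeholder arrays are hypothetical, and the float32 cast assumes Keras' default float type.

```python
import numpy as np


def append_dims(weight_matrix: np.ndarray, extra: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for `your_append`: add extra columns to each word vector."""
    if extra.ndim != 2 or extra.shape[0] != weight_matrix.shape[0]:
        raise ValueError(
            f"expected one row of extra values per word, got {extra.shape} "
            f"for a matrix of shape {weight_matrix.shape}"
        )
    # Keep the learned values and cast the extras to the matrix's dtype
    return np.hstack([weight_matrix, extra.astype(weight_matrix.dtype)])


max_length = 4                              # assumed sequence length
old = np.zeros((100, 8), dtype=np.float32)  # placeholder for the trained matrix
extra = np.ones((100, 2))                   # placeholder for your two computed dimensions
new = append_dims(old, extra)

# Dense kernel shape = (number of inputs coming from Flatten, units)
old_kernel_shape = (max_length * old.shape[1], 1)  # (32, 1)
new_kernel_shape = (max_length * new.shape[1], 1)  # (40, 1)
print(new.shape, new.dtype, old_kernel_shape, new_kernel_shape)
# (100, 10) float32 (32, 1) (40, 1)
```

Raising early on a row-count mismatch is a design choice: a wrongly shaped matrix would otherwise only fail later, when Keras tries to load it into the layer.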