Add LSTM layers after tensorflow-hub pretrained model


I'm working on text classification using a pretrained Word2vec model from TensorFlow Hub, and I want to add an LSTM layer to the Keras model. I started with the following code:

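import tensorflow as tf
import tensorflow_hub as hub

model = tf.keras.models.Sequential()
# The hub layer maps each input string to one 250-dimensional embedding vector.
model.add(hub.KerasLayer(hub.load('https://tfhub.dev/google/Wiki-words-250/2'),
                        input_shape=[],
                        dtype=tf.string,
                        trainable=True))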

After adding an LSTM layer:

model.add(tf.keras.layers.LSTM(32))

It shows me the following error:

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\input_spec.py in assert_input_compatibility(input_spec, inputs, layer_name)
    174       ndim = x.shape.ndims
    175       if ndim != spec.ndim:
--> 176         raise ValueError('Input ' + str(input_index) + ' of layer ' +
    177                          layer_name + ' is incompatible with the layer: '
    178                          'expected ndim=' + str(spec.ndim) + ', found ndim=' +

ValueError: Input 0 of layer lstm_0 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 250]

Any help is appreciated.


Solution

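  • The LSTM layer expects 3-D input of shape (batch, timesteps, features), but hub.KerasLayer returns a single 250-dimensional vector per input string, i.e. shape (None, 250). You can reshape that output so that each of the 250 values becomes a timestep with one feature: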

    model.add(hub.KerasLayer(hub.load('https://tfhub.dev/google/Wiki-words-250/2'), 
                            input_shape=[], 
                            dtype=tf.string, 
                            trainable=True))
    
    model.add(tf.keras.layers.Reshape((250, 1)))
    model.add(tf.keras.layers.LSTM(32))
    
    model.summary()
    
    Layer (type)                 Output Shape              Param #   
    =================================================================
    keras_layer_4 (KerasLayer)   (None, 250)               252343750 
    _________________________________________________________________
    reshape_2 (Reshape)          (None, 250, 1)            0         
    _________________________________________________________________
    lstm_2 (LSTM)                (None, 32)                4352      
    =================================================================
    Total params: 252,348,102
    Trainable params: 252,348,102
    Non-trainable params: 0
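
To check the shapes without downloading the hub model, here is a minimal sketch. It assumes a random tensor can stand in for the (batch, 250) embeddings the hub layer would produce. The name fake_embeddings and the batch size of 4 are illustrative only and are not part of the original code.

```python
import tensorflow as tf

# Assumption: random values stand in for the (batch, 250) embeddings that
# the Wiki-words-250 hub layer would return; the batch size of 4 is arbitrary.
fake_embeddings = tf.random.normal((4, 250))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(250,)),
    tf.keras.layers.Reshape((250, 1)),  # (batch, 250) -> (batch, 250, 1)
    tf.keras.layers.LSTM(32),           # returns the last hidden state: (batch, 32)
])

output = model(fake_embeddings)
print(output.shape)          # (4, 32)
print(model.count_params())  # 4352, the same count as the LSTM row in the summary
```

The 4352 parameters come from the four LSTM gates, each with 1 input weight, 32 recurrent weights and 1 bias per unit: 4 * 32 * (1 + 32 + 1).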