Tags: tensorflow, keras, deep-learning, huggingface-transformers, bert-language-model

How to set output_shape of BERT preprocessing layer from TensorFlow Hub?


I am building a simple BERT model for text classification, using TensorFlow Hub.

import tensorflow as tf
import tensorflow_hub as tf_hub

bert_preprocess = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
bert_encoder = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")


text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
encoded_input = bert_encoder(preprocessed_text)

l1 = tf.keras.layers.Dropout(0.3, name="dropout1")(encoded_input['pooled_output'])
l2 = tf.keras.layers.Dense(1, activation='sigmoid', name="output")(l1)

model = tf.keras.Model(inputs=[text_input], outputs=[l2])

model.summary()

Upon inspecting the output of the bert_preprocess step, I noticed that the resulting arrays have length 128. My texts are on average much shorter than 128 tokens, so I would like to decrease this length so that the preprocessing yields, say, arrays of length 40 only. However, I cannot figure out how to pass such a max_length or output_shape parameter to bert_preprocess.
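
For example, calling the layer directly shows the fixed length (a minimal check, not from the original post; the dict keys match the model summary below):

out = bert_preprocess(tf.constant(["we have a very sunny day today don't you think so?"]))
print({key: value.shape for key, value in out.items()})
# Each of 'input_word_ids', 'input_mask' and 'input_type_ids'
# has shape (1, 128), regardless of the actual sentence length.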

Printed model summary:

__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 text (InputLayer)              [(None,)]            0           []                               
                                                                                                  
 keras_layer_16 (KerasLayer)    {'input_word_ids':   0           ['text[0][0]']                   
                                (None, 128),                                                      
                                 'input_type_ids':                                                
                                (None, 128),                                                      
                                 'input_mask': (Non                                               
                                e, 128)}                                                          
                                                                                                  
 keras_layer_17 (KerasLayer)    {'sequence_output':  109482241   ['keras_layer_16[0][0]',         
                                 (None, 128, 768),                'keras_layer_16[0][1]',         
                                 'default': (None,                'keras_layer_16[0][2]']         
                                768),                                                             
                                 'encoder_outputs':                                               
                                 [(None, 128, 768),                                               
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
                                 (None, 128, 768),                                                
...
Total params: 109,483,010
Trainable params: 769
Non-trainable params: 109,482,241

Checking the documentation, I found there is an output_shape argument for tf_hub.KerasLayer, so I tried passing the following arguments:

bert_preprocess = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3", output_shape=(64,))
bert_preprocess = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3", output_shape=[64])

However, in both of these cases, the following line throws an error:

bert_preprocess(["we have a very sunny day today don't you think so?"])

Error:

ValueError                                Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_23952\4048288771.py in <module>
----> 1 bert_preprocess("we have a very sunny day today don't you think so?")

~\AppData\Roaming\Python\Python37\site-packages\keras\utils\traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

c:\Users\username\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow_hub\keras_layer.py in call(self, inputs, training)
    237       result = smart_cond.smart_cond(training,
    238                                      lambda: f(training=True),
--> 239                                      lambda: f(training=False))
    240 
    241     # Unwrap dicts returned by signatures.

c:\Users\username\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow_hub\keras_layer.py in <lambda>()
    237       result = smart_cond.smart_cond(training,
    238                                      lambda: f(training=True),
--> 239                                      lambda: f(training=False))
    240 
    241     # Unwrap dicts returned by signatures.
...
  Keyword arguments: {}

Call arguments received:
  • inputs="we have a very sunny day today don't you think so?"
  • training=False

Solution

  • You need to go one level lower to achieve this. What you want is actually shown on the preprocessing model's page, but it is not properly introduced there.

    You can wrap your intention into a custom TF layer:

    class ModifiedBertPreprocess(tf.keras.layers.Layer):
        def __init__(self, max_len):
            super(ModifiedBertPreprocess, self).__init__()

            # Load the preprocessing SavedModel directly (not as a KerasLayer)
            # so that its tokenize and bert_pack_inputs methods are accessible.
            preprocessor = tf_hub.load(
                "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")

            # Step 1: raw strings -> ragged batches of token ids.
            self.tokenizer = tf_hub.KerasLayer(preprocessor.tokenize, name="tokenizer")

            # Step 2: pack token ids into fixed-length encoder inputs;
            # seq_length overrides the default length of 128.
            self.prep_layer = tf_hub.KerasLayer(
                preprocessor.bert_pack_inputs,
                arguments={"seq_length": max_len})

        def call(self, inputs, training=None):
            # bert_pack_inputs expects a list with one entry per text segment,
            # so each input tensor in the list is tokenized separately.
            tokenized = [self.tokenizer(seq) for seq in inputs]
            return self.prep_layer(tokenized)
    

    Basically, you tokenize and pack your inputs yourself. The preprocessor exposes a method named bert_pack_inputs, whose seq_length argument lets you specify the maximum length of the inputs.

    Note that self.tokenizer expects its inputs in a list. Most likely this is to allow it to accept multiple text segments, as in the sketch below.
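
    To see these two steps in isolation, here is a minimal eager-mode sketch of the same idea (it assumes the tf and tf_hub imports from above; the example sentence and variable names are arbitrary):

    preprocessor = tf_hub.load(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")

    tokenize = tf_hub.KerasLayer(preprocessor.tokenize)
    pack = tf_hub.KerasLayer(preprocessor.bert_pack_inputs,
                             arguments={"seq_length": 40})

    tokens = tokenize(tf.constant(["a short example sentence"]))  # ragged token ids
    packed = pack([tokens])  # note the list: one entry per text segment

    print({key: value.shape for key, value in packed.items()})
    # Each of input_word_ids / input_mask / input_type_ids now has shape (1, 40).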

    Your model should look like this:

    bert_encoder = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")
    
    text_input = [tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')]
    
    bert_seq_changed = ModifiedBertPreprocess(max_len=40)
    
    encoder_inputs = bert_seq_changed(text_input)
    
    encoded_input = bert_encoder(encoder_inputs)
    
    l1 = tf.keras.layers.Dropout(0.3, name="dropout1")(encoded_input['pooled_output'])
    l2 = tf.keras.layers.Dense(1, activation='sigmoid', name="output")(l1)
    
    model = tf.keras.Model(inputs=text_input, outputs=[l2])
    

    Note that text_input is now a list, because self.tokenizer's input signature expects a list; it is passed to tf.keras.Model directly as the list of inputs.

    Here's the model summary:

    Model: "model"
    __________________________________________________________________________________________________
     Layer (type)                   Output Shape         Param #     Connected to                     
    ==================================================================================================
     text (InputLayer)              [(None,)]            0           []                               
                                                                                                      
     modified_bert_preprocess (Modi  {'input_type_ids':   0          ['text[0][0]']                   
     fiedBertPreprocess)            (None, 40),                                                       
                                     'input_word_ids':                                                
                                    (None, 40),                                                       
                                     'input_mask': (Non                                               
                                    e, 40)}                                                           
                                                                                                      
     keras_layer (KerasLayer)       {'encoder_outputs':  109482241   ['modified_bert_preprocess[0][0]'
                                     [(None, 40, 768),               , 'modified_bert_preprocess[0][1]
                                     (None, 40, 768),                ',                               
                                     (None, 40, 768),                 'modified_bert_preprocess[0][2]'
                                     (None, 40, 768),                ]                                
                                     (None, 40, 768),                                                 
                                     (None, 40, 768),                                                 
                                     (None, 40, 768),                                                 
                                     (None, 40, 768),                                                 
                                     (None, 40, 768),                                                 
                                     (None, 40, 768),                                                 
                                     (None, 40, 768),                                                 
                                     (None, 40, 768)],                                                
                                     'default': (None,                                                
                                    768),                                                             
                                     'pooled_output': (                                               
                                    None, 768),                                                       
                                     'sequence_output':                                               
                                     (None, 40, 768)}                                                 
                                                                                                      
     dropout1 (Dropout)             (None, 768)          0           ['keras_layer[0][13]']           
                                                                                                      
     output (Dense)                 (None, 1)            769         ['dropout1[0][0]']               
                                                                                                      
    ==================================================================================================
    Total params: 109,483,010
    Trainable params: 769
    Non-trainable params: 109,482,241
    

    When calling the custom preprocessing layer:

    bert_seq_changed([tf.convert_to_tensor(["we have a very sunny day today don't you think so?"], dtype=tf.string)])
    

    Notice that the inputs are wrapped in a list here as well. The model itself can be called in both ways:

    model([tf.convert_to_tensor(["we have a very sunny day today don't you think so?"], dtype=tf.string)])
    

    or

    model(tf.convert_to_tensor(["we have a very sunny day today don't you think so?"], dtype=tf.string))
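
    From here the model can be compiled and trained as usual. A minimal sketch with hypothetical toy data (binary labels to match the sigmoid output; none of this is from the original post):

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    texts = tf.constant(["what a lovely sunny day",
                         "this is a terrible storm"])  # hypothetical examples
    labels = tf.constant([1.0, 0.0])                   # hypothetical binary labels

    model.fit(texts, labels, epochs=1)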