Search code examples
python tensorflow machine-learning keras huggingface-transformers

TypeError when trying to apply custom loss in a multilabel classification problem


I am trying to solve a multilabel text classification problem using BERT from the Hugging Face transformers library. The model is defined as follows:

def create_model(encoder, nb_classes=3, lr=1e-5):
    """Build, compile and return a Keras classifier on top of ``encoder``.

    Args:
        encoder: Callable model invoked with a dict holding ``input_ids`` and
            ``attention_mask``; only element 0 of what it returns is used.
        nb_classes: Number of units in the final sigmoid ``Dense`` layer.
        lr: Learning rate passed to the Adam optimizer.

    Returns:
        The compiled ``tf.keras.Model``; its summary is printed before it is
        returned.
    """

    # inputs
    # Both inputs are declared as dense int32 tensors of fixed length 512.
    input_ids = tf.keras.Input(shape=(512,), ragged=False,
                               dtype=tf.int32, name='input_ids')
    input_attention_mask = tf.keras.Input(shape=(512,), ragged=False,
                                          dtype=tf.int32, name='attention_mask')
    # transformer
    # NOTE(review): what element 0 holds (and its shape) depends on the class
    # of ``encoder`` -- confirm it matches what the Dense head below expects.
    output = encoder({'input_ids': input_ids, 
                      'attention_mask': input_attention_mask})[0]
    Y = tf.keras.layers.BatchNormalization()(output)
    # Sigmoid (not softmax) so each of the nb_classes outputs is scored
    # independently.
    Y = tf.keras.layers.Dense(nb_classes, activation='sigmoid')(Y)

    # compilation
    model = tf.keras.Model(inputs=[input_ids, input_attention_mask], 
                           outputs=[Y])
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    # losses
    # Built-in losses that were tried before switching to multilabel_loss:
    # loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    # loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    # multilabel_loss is the hand-written cross entropy defined in this file.
    model.compile(optimizer=optimizer, 
                  loss=multilabel_loss, metrics=['acc'])
    model.summary()
    return model

As you can see, I tried to use tf.keras.losses, but it did not work (throwing AttributeError: 'Tensor' object has no attribute 'nested_row_splits'), so I defined a simple cross entropy by hand:

def multilabel_loss(y_true, y_pred):
    """Return the binary cross entropy summed over every element.

    ``y_pred`` is converted to a tensor and ``y_true`` is cast to its dtype.
    A small constant (1e-8) is added inside each logarithm so that
    predictions of exactly 0 or 1 do not produce ``log(0)``. The result is a
    single scalar: the negated sum of all per-element terms (no averaging).
    """
    predictions = tf.convert_to_tensor(y_pred)
    targets = tf.cast(y_true, predictions.dtype)

    # Contribution of the positive labels and of the negative labels.
    positive_term = targets * tf.math.log(predictions + 1e-8)
    negative_term = (1 - targets) * tf.math.log(1 - predictions + 1e-8)

    return -tf.reduce_sum(positive_term + negative_term, name='xentropy')

The model is created with strategy.scope() as shown below, using 'distil-bert-uncased' as a checkpoint:

# Create the encoder and the Keras model inside the strategy's scope.
with strategy.scope():
    encoder = TFAutoModelForSequenceClassification.from_pretrained(checkpoint)
    # Alternative encoder class that was tried:
    #encoder = TFRobertaForSequenceClassification.from_pretrained(checkpoint)
    model = create_model(encoder)

The labels are binary arrays:

163350    [0, 0, 1]
118940    [0, 0, 1]
65243     [0, 0, 1]
30011     [0, 0, 1]
189713    [0, 1, 0]

They are combined with the tokenized texts into a tf.data.Dataset in the following function:

def tf_text_data_prep(df):
    """Tokenize a pandas DataFrame and return it as a ``tf.data.Dataset``.

    The frame is loaded with ``Dataset.from_pandas`` (presumably the Hugging
    Face ``datasets`` class), mapped through ``tokenize_function`` in batches
    with the "Text" and '__index_level_0__' columns removed, and switched to
    TensorFlow format. Every column named in ``tokenizer.model_input_names``
    is converted with ``.to_tensor()`` and paired with the "label" column.

    Returns:
        A ``tf.data.Dataset`` of ``(features_dict, label)`` slices.
    """
    raw_ds = Dataset.from_pandas(df)

    map_options = dict(
        batched=True,
        num_proc=strategy.num_replicas_in_sync,
        remove_columns=["Text", '__index_level_0__'],
        load_from_cache_file=True,
    )
    encoded_ds = raw_ds.map(tokenize_function, **map_options)

    # Convert to tensorflow
    tf_view = encoded_ds.with_format("tensorflow")
    model_inputs = {}
    for column in tokenizer.model_input_names:
        model_inputs[column] = tf_view[column].to_tensor()

    # NOTE(review): unlike the feature columns, the labels are passed on
    # without .to_tensor() -- confirm they are already a dense tensor.
    labels = tf_view["label"]
    return tf.data.Dataset.from_tensor_slices((model_inputs, labels))

The problem is that when I launch the training, I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-62-720b4634d50e> in <module>()
----> 1 get_ipython().run_cell_magic('time', '', 'steps_per_epoch = int(BUFFER_SIZE // BATCH_SIZE)\nprint(\n    f"Model Params:\\nbatch_size: {BATCH_SIZE}\\nEpochs: {EPOCHS}\\n"\n    f"Step p. Epoch: {steps_per_epoch}\\n"\n    f"Initial Learning rate: {INITAL_LEARNING_RATE}"\n)\nhistory = model.fit(\n    train_ds,\n    validation_data=val_ds,\n    batch_size=BATCH_SIZE,\n    epochs=EPOCHS,\n    callbacks=callbacks,\n    verbose=1,\n)')

12 frames
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py in run_cell_magic(self, magic_name, line, cell)
   2115             magic_arg_s = self.var_expand(line, stack_depth)
   2116             with self.builtin_trap:
-> 2117                 result = fn(magic_arg_s, cell)
   2118             return result
   2119 

<decorator-gen-53> in time(self, line, cell, local_ns)

/usr/local/lib/python3.7/dist-packages/IPython/core/magic.py in <lambda>(f, *a, **k)
    186     # but it's overkill for just that one bit of state.
    187     def magic_deco(arg):
--> 188         call = lambda f, *a, **k: f(*a, **k)
    189 
    190         if callable(arg):

/usr/local/lib/python3.7/dist-packages/IPython/core/magics/execution.py in time(self, line, cell, local_ns)
   1191         else:
   1192             st = clock2()
-> 1193             exec(code, glob, local_ns)
   1194             end = clock2()
   1195             out = None

<timed exec> in <module>()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1176                 _r=1):
   1177               callbacks.on_train_batch_begin(step)
-> 1178               tmp_logs = self.train_function(iterator)
   1179               if data_handler.should_sync:
   1180                 context.async_wait()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    887 
    888       with OptionalXlaContext(self._jit_compile):
--> 889         result = self._call(*args, **kwds)
    890 
    891       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    762     self._concrete_stateful_fn = (
    763         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 764             *args, **kwds))
    765 
    766     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3048       args, kwargs = None, None
   3049     with self._lock:
-> 3050       graph_function, _ = self._maybe_define_function(args, kwargs)
   3051     return graph_function
   3052 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3442 
   3443           self._function_cache.missed.add(call_context_key)
-> 3444           graph_function = self._create_graph_function(args, kwargs)
   3445           self._function_cache.primary[cache_key] = graph_function
   3446 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3287             arg_names=arg_names,
   3288             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3289             capture_by_value=self._capture_by_value),
   3290         self._function_attributes,
   3291         function_spec=self.function_spec,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    997         _, original_func = tf_decorator.unwrap(python_func)
    998 
--> 999       func_outputs = python_func(*func_args, **func_kwargs)
   1000 
   1001       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    670         # the function a weak reference to itself to avoid a reference cycle.
    671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    673         return out
    674 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    984           except Exception as e:  # pylint:disable=broad-except
    985             if hasattr(e, "ag_error_metadata"):
--> 986               raise e.ag_error_metadata.to_exception(e)
    987             else:
    988               raise

TypeError: in user code:

    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:850 train_function  *
        return step_function(self, iterator)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:840 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1285 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2833 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3608 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:833 run_step  **
        outputs = model.train_step(data)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:795 train_step
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/compile_utils.py:460 update_state
        metric_obj.update_state(y_t, y_p, sample_weight=mask)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/metrics_utils.py:86 decorated
        update_op = update_state_fn(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/metrics.py:177 update_state_fn
        return ag_update_state(*args, **kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/metrics.py:659 update_state  **
        [y_true, y_pred], sample_weight)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/metrics_utils.py:546 ragged_assert_compatible_and_get_flat_values
        raise TypeError('One of the inputs does not have acceptable types.')

    TypeError: One of the inputs does not have acceptable types.

This same approach worked for ordinary binary classification, but not for multilabel. I'd appreciate any help regarding the error or the approach in general.


Solution

  • The issue is that you are using TFAutoModelForSequenceClassification, i.e. a "ForSequenceClassification" model. If you look at its summary, you will find that it ends in a Dense classification layer, and hence it is not the bare encoder you want.

    # Load the sequence-classification variant and print its layer summary.
    encoder = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
    encoder.summary()
    '''
    Result:
    All model checkpoint layers were used when initializing TFBertForSequenceClassification.
    
    Some layers of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier']
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    Model: "tf_bert_for_sequence_classification_2"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    bert (TFBertMainLayer)       multiple                  109482240 
    _________________________________________________________________
    dropout_187 (Dropout)        multiple                  0         
    _________________________________________________________________
    classifier (Dense)           multiple                  1538      
    =================================================================
    Total params: 109,483,778
    Trainable params: 109,483,778
    Non-trainable params: 0
    '''
    

    But you want to use it as an encoder, and hence you have to load it like this:

    from transformers import TFBertModel
    # Load the bare BERT model (no classification head) and build the classifier on it.
    encoder = TFBertModel.from_pretrained('bert-base-uncased')
    model = create_model(encoder)
    
    # The summary below lists only the BERT main layer.
    encoder.summary()
    '''
    Model: "tf_bert_model_2"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    bert (TFBertMainLayer)       multiple                  109482240 
    =================================================================
    Total params: 109,482,240
    Trainable params: 109,482,240
    Non-trainable params: 0
    '''
    

    As you can see, your encoder now returns the output of BERT itself. Using it in create_model now makes sense, but the following line in the create_model function will still cause an error:

    # Line as originally written in create_model: takes element 0 of the encoder result.
    output = encoder({'input_ids': input_ids, 
                          'attention_mask': input_attention_mask})[0]
    

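    This is because the output at index 0 has shape (batch_size, token_length, embedding), but we want the representation of the [CLS] token, which has shape (batch_size, embedding). For TFBertModel that is the pooled output at index 1, so the line has to be updated as below. (Encoders without a pooling layer, such as DistilBERT, do not return this second output; there you can slice the [CLS] position out of index 0 instead, e.g. [0][:, 0, :].)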

    # Updated line: takes element 1 of the encoder result instead of element 0.
    output = encoder({'input_ids': input_ids, 
                          'attention_mask': input_attention_mask})[1]
    

    Also, you are currently fixing the input length to 512; we should set it to None instead so that the model can accept variable-length input, as below:

    # shape=(None,) leaves the sequence length unspecified instead of fixing it to 512.
    input_ids = tf.keras.Input(shape=(None,), ragged=False, dtype=tf.int32, name='input_ids')
    input_attention_mask = tf.keras.Input(shape=(None,), ragged=False, dtype=tf.int32, name='attention_mask')
    

    After making all these changes, below is the result of a sample run:

    # Load the tokenizer and bare BERT encoder, then build the classifier.
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    encoder = TFBertModel.from_pretrained('bert-base-uncased')
    model = create_model(encoder)
    
    
    # Tokenize one sample sentence and pass its ids and mask to the model.
    inputs = tokenizer('hello world', return_tensors='tf')
    model.predict((inputs['input_ids'], inputs['attention_mask']))
    
    '''
    Results:
    array([[0.7867866 , 0.65974414, 0.45628983]], dtype=float32)
    '''